From 71805b2a22e637df72da5fbdb487ffb337fdbd4f Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 25 Feb 2025 15:13:35 -0800 Subject: [PATCH] feat: Add tool calling to fast chat completions (#1109) --- letta/constants.py | 7 + letta/helpers/composio_helpers.py | 5 +- letta/helpers/tool_execution_helper.py | 171 ++++++++++++ .../schemas/openai/chat_completion_request.py | 2 +- .../chat_completions/chat_completions.py | 254 +++++++++++++++--- letta/server/rest_api/utils.py | 5 +- letta/services/agent_manager.py | 36 +++ letta/services/block_manager.py | 6 +- letta/services/message_manager.py | 5 +- tests/integration_test_chat_completions.py | 59 +++- 10 files changed, 507 insertions(+), 43 deletions(-) create mode 100644 letta/helpers/tool_execution_helper.py diff --git a/letta/constants.py b/letta/constants.py index 35ab7cb4..468afa4c 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -52,6 +52,8 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", " BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"] # Multi agent tools MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"] +# Set of all built-in Letta tools +LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS) # The name of the tool used to send message to the user # May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...) @@ -59,6 +61,11 @@ MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_t DEFAULT_MESSAGE_TOOL = "send_message" DEFAULT_MESSAGE_TOOL_KWARG = "message" +PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg" + +REQUEST_HEARTBEAT_PARAM = "request_heartbeat" + + # Structured output models STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"} diff --git a/letta/helpers/composio_helpers.py b/letta/helpers/composio_helpers.py index 8a8c3249..a3c518ec 100644 --- a/letta/helpers/composio_helpers.py +++ b/letta/helpers/composio_helpers.py @@ -6,10 +6,11 @@ from letta.services.sandbox_config_manager import SandboxConfigManager from letta.settings import tool_settings -def get_composio_api_key(actor: User, logger: Logger) -> Optional[str]: +def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Optional[str]: api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor) if not api_keys: - logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...") + if logger: + logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...") if tool_settings.composio_api_key: return tool_settings.composio_api_key else: diff --git a/letta/helpers/tool_execution_helper.py b/letta/helpers/tool_execution_helper.py new file mode 100644 index 00000000..948772ee --- /dev/null +++ b/letta/helpers/tool_execution_helper.py @@ -0,0 +1,171 @@ +from collections import OrderedDict +from typing import Any, Dict, Optional + +from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, PRE_EXECUTION_MESSAGE_ARG +from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source +from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name +from letta.helpers.composio_helpers import get_composio_api_key +from letta.orm.enums import ToolType +from letta.schemas.agent import AgentState +from letta.schemas.sandbox_config import SandboxRunResult +from letta.schemas.tool import Tool +from letta.schemas.user import User +from letta.services.tool_execution_sandbox import ToolExecutionSandbox +from letta.utils import get_friendly_error_msg + + +def enable_strict_mode(tool_schema: Dict[str, Any]) -> Dict[str, Any]: + """Enables strict mode for a tool schema by setting 'strict' to True and + disallowing additional properties in the parameters. + + Args: + tool_schema (Dict[str, Any]): The original tool schema. + + Returns: + Dict[str, Any]: A new tool schema with strict mode enabled. + """ + schema = tool_schema.copy() + + # Enable strict mode + schema["strict"] = True + + # Ensure parameters is a valid dictionary + parameters = schema.get("parameters", {}) + + if isinstance(parameters, dict) and parameters.get("type") == "object": + # Set additionalProperties to False + parameters["additionalProperties"] = False + schema["parameters"] = parameters + + return schema + + +def add_pre_execution_message(tool_schema: Dict[str, Any]) -> Dict[str, Any]: + """Adds a `pre_execution_message` parameter to a tool schema to prompt a natural, human-like message before executing the tool. + + Args: + tool_schema (Dict[str, Any]): The original tool schema. + + Returns: + Dict[str, Any]: A new tool schema with the `pre_execution_message` field added at the beginning. + """ + schema = tool_schema.copy() + parameters = schema.get("parameters", {}) + + if not isinstance(parameters, dict) or parameters.get("type") != "object": + return schema # Do not modify if schema is not valid + + properties = parameters.get("properties", {}) + required = parameters.get("required", []) + + # Define the new `pre_execution_message` field with a refined description + pre_execution_message_field = { + "type": "string", + "description": ( + "A concise message to be uttered before executing this tool. " + "This should sound natural, as if a person is casually announcing their next action." + "You MUST also include punctuation at the end of this message." + ), + } + + # Ensure the pre-execution message is the first field in properties + updated_properties = OrderedDict() + updated_properties[PRE_EXECUTION_MESSAGE_ARG] = pre_execution_message_field + updated_properties.update(properties) # Retain all existing properties + + # Ensure pre-execution message is the first required field + if PRE_EXECUTION_MESSAGE_ARG not in required: + required = [PRE_EXECUTION_MESSAGE_ARG] + required + + # Update the schema with ordered properties and required list + schema["parameters"] = { + **parameters, + "properties": dict(updated_properties), # Convert OrderedDict back to dict + "required": required, + } + + return schema + + +def remove_request_heartbeat(tool_schema: Dict[str, Any]) -> Dict[str, Any]: + """Removes the `request_heartbeat` parameter from a tool schema if it exists. + + Args: + tool_schema (Dict[str, Any]): The original tool schema. + + Returns: + Dict[str, Any]: A new tool schema without `request_heartbeat`. + """ + schema = tool_schema.copy() + parameters = schema.get("parameters", {}) + + if isinstance(parameters, dict): + properties = parameters.get("properties", {}) + required = parameters.get("required", []) + + # Remove the `request_heartbeat` property if it exists + if "request_heartbeat" in properties: + properties.pop("request_heartbeat") + + # Remove `request_heartbeat` from required fields if present + if "request_heartbeat" in required: + required = [r for r in required if r != "request_heartbeat"] + + # Update parameters with modified properties and required list + schema["parameters"] = {**parameters, "properties": properties, "required": required} + + return schema + + +# TODO: Deprecate the `execute_external_tool` function on the agent body +def execute_external_tool( + agent_state: AgentState, + function_name: str, + function_args: dict, + target_letta_tool: Tool, + actor: User, + allow_agent_state_modifications: bool = False, +) -> tuple[Any, Optional[SandboxRunResult]]: + # TODO: need to have an AgentState object that actually has full access to the block data + # this is because the sandbox tools need to be able to access block.value to edit this data + try: + if target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO: + action_name = generate_composio_action_from_func_name(target_letta_tool.name) + # Get entity ID from the agent_state + entity_id = None + for env_var in agent_state.tool_exec_environment_variables: + if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY: + entity_id = env_var.value + # Get composio_api_key + composio_api_key = get_composio_api_key(actor=actor) + function_response = execute_composio_action( + action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id + ) + return function_response, None + elif target_letta_tool.tool_type == ToolType.CUSTOM: + # Parse the source code to extract function annotations + annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name) + # Coerce the function arguments to the correct types based on the annotations + function_args = coerce_dict_args_by_annotations(function_args, annotations) + + # execute tool in a sandbox + # TODO: allow agent_state to specify which sandbox to execute tools in + # TODO: This is only temporary, can remove after we publish a pip package with this object + if allow_agent_state_modifications: + agent_state_copy = agent_state.__deepcopy__() + agent_state_copy.tools = [] + agent_state_copy.tool_rules = [] + else: + agent_state_copy = None + + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy) + function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state + # TODO: Bring this back + # if allow_agent_state_modifications and updated_agent_state is not None: + # self.update_memory_if_changed(updated_agent_state.memory) + return function_response, sandbox_run_result + except Exception as e: + # Need to catch error here, or else trunction wont happen + # TODO: modify to function execution error + function_response = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)) + return function_response, None diff --git a/letta/schemas/openai/chat_completion_request.py b/letta/schemas/openai/chat_completion_request.py index 5b7b2743..12486bca 100644 --- a/letta/schemas/openai/chat_completion_request.py +++ b/letta/schemas/openai/chat_completion_request.py @@ -99,7 +99,7 @@ class ChatCompletionRequest(BaseModel): """https://platform.openai.com/docs/api-reference/chat/create""" model: str - messages: List[ChatMessage] + messages: List[Union[ChatMessage, Dict]] frequency_penalty: Optional[float] = 0 logit_bias: Optional[Dict[str, int]] = None logprobs: Optional[bool] = False diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index 13fd2347..428dbbd9 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -1,19 +1,39 @@ import asyncio +import json +import uuid from typing import TYPE_CHECKING, List, Optional, Union import httpx import openai from fastapi import APIRouter, Body, Depends, Header, HTTPException from fastapi.responses import StreamingResponse +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta from openai.types.chat.completion_create_params import CompletionCreateParams from starlette.concurrency import run_in_threadpool from letta.agent import Agent -from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_TOOL_SET, NON_USER_MSG_PREFIX, PRE_EXECUTION_MESSAGE_ARG +from letta.helpers.tool_execution_helper import ( + add_pre_execution_message, + enable_strict_mode, + execute_external_tool, + remove_request_heartbeat, +) from letta.log import get_logger +from letta.orm.enums import ToolType from letta.schemas.message import Message, MessageCreate +from letta.schemas.openai.chat_completion_request import ( + AssistantMessage, + ChatCompletionRequest, + Tool, + ToolCall, + ToolCallFunction, + ToolMessage, + UserMessage, +) from letta.schemas.user import User from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface +from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser # TODO this belongs in a controller! from letta.server.rest_api.utils import ( @@ -52,20 +72,53 @@ async def create_fast_chat_completions( server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), ): - # TODO: This is necessary, we need to factor out CompletionCreateParams due to weird behavior + actor = server.user_manager.get_user_or_default(user_id=user_id) + agent_id = str(completion_request.get("user", None)) if agent_id is None: - error_msg = "Must pass agent_id in the 'user' field" - logger.error(error_msg) - raise HTTPException(status_code=400, detail=error_msg) - model = completion_request.get("model") + raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field") - actor = server.user_manager.get_user_or_default(user_id=user_id) + agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) + if agent_state.llm_config.model_endpoint_type != "openai": + raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.") + + # Convert Letta messages to OpenAI messages + in_context_messages = server.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor) + openai_messages = convert_letta_messages_to_openai(in_context_messages) + + # Also parse user input from completion_request and append + input_message = get_messages_from_completion_request(completion_request)[-1] + openai_messages.append(input_message) + + # Tools we allow this agent to call + tools = [t for t in agent_state.tools if t.name not in LETTA_TOOL_SET and t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}] + + # Initial request + openai_request = ChatCompletionRequest( + model=agent_state.llm_config.model, + messages=openai_messages, + # TODO: This nested thing here is so ugly, need to refactor + tools=( + [ + Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema)))) + for t in tools + ] + if tools + else None + ), + tool_choice="auto", + user=user_id, + max_completion_tokens=agent_state.llm_config.max_tokens, + temperature=agent_state.llm_config.temperature, + stream=True, + ) + + # Create the OpenAI async client client = openai.AsyncClient( api_key=model_settings.openai_api_key, max_retries=0, http_client=httpx.AsyncClient( - timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0), + timeout=httpx.Timeout(connect=15.0, read=30.0, write=15.0, pool=15.0), follow_redirects=True, limits=httpx.Limits( max_connections=50, @@ -75,38 +128,175 @@ async def create_fast_chat_completions( ), ) - # Magic message manipulating - input_message = get_messages_from_completion_request(completion_request)[-1] - completion_request.pop("messages") - - # Get in context messages - in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor) - openai_dict_in_context_messages = convert_letta_messages_to_openai(in_context_messages) - openai_dict_in_context_messages.append(input_message) + # The messages we want to persist to the Letta agent + user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor) + message_db_queue = [user_message] async def event_stream(): - # TODO: Factor this out into separate interface - response_accumulator = [] + """ + A function-calling loop: + - We stream partial tokens. + - If we detect a tool call (finish_reason="tool_calls"), we parse it, + add two messages to the conversation: + (a) assistant message with tool_calls referencing the same ID + (b) a tool message referencing that ID, containing the tool result. + - Re-invoke the OpenAI request with updated conversation, streaming again. + - End when finish_reason="stop" or no more tool calls. + """ - stream = await client.chat.completions.create(**completion_request, messages=openai_dict_in_context_messages) + # We'll keep updating this conversation in a loop + conversation = openai_messages[:] - async with stream: - async for chunk in stream: - if chunk.choices and chunk.choices[0].delta.content: - # TODO: This does not support tool calling right now - response_accumulator.append(chunk.choices[0].delta.content) - yield f"data: {chunk.model_dump_json()}\n\n" + while True: + # Make the streaming request to OpenAI + stream = await client.chat.completions.create(**openai_request.model_dump(exclude_unset=True)) - # Construct messages - user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor) - assistant_message = create_assistant_message_from_openai_response( - response_text="".join(response_accumulator), agent_id=agent_id, model=str(model), actor=actor - ) + content_buffer = [] + tool_call_name = None + tool_call_args_str = "" + tool_call_id = None + tool_call_happened = False + finish_reason_stop = False + optimistic_json_parser = OptimisticJSONParser(strict=True) + current_parsed_json_result = {} + + async with stream: + async for chunk in stream: + choice = chunk.choices[0] + delta = choice.delta + finish_reason = choice.finish_reason # "tool_calls", "stop", or None + + if delta.content: + content_buffer.append(delta.content) + yield f"data: {chunk.model_dump_json()}\n\n" + + # CASE B: Partial tool call info + if delta.tool_calls: + # Typically there's only one in delta.tool_calls + tc = delta.tool_calls[0] + if tc.function.name: + tool_call_name = tc.function.name + if tc.function.arguments: + tool_call_args_str += tc.function.arguments + + # See if we can stream out the pre-execution message + parsed_args = optimistic_json_parser.parse(tool_call_args_str) + if parsed_args.get( + PRE_EXECUTION_MESSAGE_ARG + ) and current_parsed_json_result.get( # Ensure key exists and is not None/empty + PRE_EXECUTION_MESSAGE_ARG + ) != parsed_args.get( + PRE_EXECUTION_MESSAGE_ARG + ): + # Only stream if there's something new to stream + # We do this way to avoid hanging JSON at the end of the stream, e.g. '}' + if parsed_args != current_parsed_json_result: + current_parsed_json_result = parsed_args + synthetic_chunk = ChatCompletionChunk( + id=chunk.id, + object=chunk.object, + created=chunk.created, + model=chunk.model, + choices=[ + Choice( + index=choice.index, + delta=ChoiceDelta(content=tc.function.arguments, role="assistant"), + finish_reason=None, + ) + ], + ) + + yield f"data: {synthetic_chunk.model_dump_json()}\n\n" + + # We might generate a unique ID for the tool call + if tc.id: + tool_call_id = tc.id + + # Check finish_reason + if finish_reason == "tool_calls": + tool_call_happened = True + break + elif finish_reason == "stop": + finish_reason_stop = True + break + + if content_buffer: + # We treat that partial text as an assistant message + content = "".join(content_buffer) + conversation.append({"role": "assistant", "content": content}) + + # Create an assistant message here to persist later + assistant_message = create_assistant_message_from_openai_response( + response_text=content, agent_id=agent_id, model=agent_state.llm_config.model, actor=actor + ) + message_db_queue.append(assistant_message) + + if tool_call_happened: + # Parse the tool call arguments + try: + tool_args = json.loads(tool_call_args_str) + except json.JSONDecodeError: + tool_args = {} + + if not tool_call_id: + # If no tool_call_id given by the model, generate one + tool_call_id = f"call_{uuid.uuid4().hex[:8]}" + + # 1) Insert the "assistant" message with the tool_calls field + # referencing the same tool_call_id + assistant_tool_call_msg = AssistantMessage( + content=None, + tool_calls=[ToolCall(id=tool_call_id, function=ToolCallFunction(name=tool_call_name, arguments=tool_call_args_str))], + ) + + conversation.append(assistant_tool_call_msg.model_dump()) + + # 2) Execute the tool + target_tool = next((x for x in tools if x.name == tool_call_name), None) + if not target_tool: + # Tool not found, handle error + yield f"data: {json.dumps({'error': 'Tool not found', 'tool': tool_call_name})}\n\n" + break + + try: + tool_result, _ = execute_external_tool( + agent_state=agent_state, + function_name=tool_call_name, + function_args=tool_args, + target_letta_tool=target_tool, + actor=actor, + allow_agent_state_modifications=False, + ) + except Exception as e: + tool_result = f"Failed to call tool. Error: {e}" + + # 3) Insert the "tool" message referencing the same tool_call_id + tool_message = ToolMessage(content=json.dumps({"result": tool_result}), tool_call_id=tool_call_id) + + conversation.append(tool_message.model_dump()) + + # 4) Add a user message prompting the tool call result summarization + heartbeat_user_message = UserMessage( + content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user.", + ) + conversation.append(heartbeat_user_message.model_dump()) + + # Now, re-invoke OpenAI with the updated conversation + openai_request.messages = conversation + + continue # Start the while loop again + + if finish_reason_stop: + # Model is done, no more calls + break + + # If we reach here, no tool call, no "stop", but we've ended streaming + # Possibly a model error or some other finish reason. We'll just end. + break - # Persist both in one synchronous DB call, done in a threadpool await run_in_threadpool( server.agent_manager.append_to_in_context_messages, - [user_message, assistant_message], + message_db_queue, agent_id=agent_id, actor=actor, ) diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index d5bf4520..8008d056 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -7,7 +7,6 @@ from datetime import datetime, timezone from enum import Enum from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast -import pytz from fastapi import Header, HTTPException from openai.types.chat import ChatCompletionMessageParam from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall @@ -145,7 +144,7 @@ def create_user_message(input_message: dict, agent_id: str, actor: User) -> Mess Converts a user input message into the internal structured format. """ # Generate timestamp in the correct format - now = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %I:%M:%S %p %Z%z") + now = datetime.now(timezone.utc).isoformat() # Format message as structured JSON structured_message = {"type": "user_message", "message": input_message["content"], "time": now} @@ -197,7 +196,7 @@ def create_assistant_message_from_openai_response( agent_id=agent_id, model=model, tool_calls=[tool_call], - tool_call_id=None, + tool_call_id=tool_call_id, created_at=datetime.now(timezone.utc), ) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index fb450d3d..8d9743ea 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -21,8 +21,10 @@ from letta.orm.sqlite_functions import adapt_array from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent from letta.schemas.block import Block as PydanticBlock +from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from letta.schemas.memory import Memory from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate from letta.schemas.passage import Passage as PydanticPassage @@ -613,6 +615,40 @@ class AgentManager: ) return self.append_to_in_context_messages([system_message], agent_id=agent_state.id, actor=actor) + # TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version + @enforce_types + def update_memory_if_changed(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState: + """ + Update internal memory object and system prompt if there have been modifications. + + Args: + new_memory (Memory): the new memory object to compare to the current memory object + + Returns: + modified (bool): whether the memory was updated + """ + agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor) + if agent_state.memory.compile() != new_memory.compile(): + # update the blocks (LRW) in the DB + for label in agent_state.memory.list_block_labels(): + updated_value = new_memory.get_block(label).value + if updated_value != agent_state.memory.get_block(label).value: + # update the block if it's changed + block_id = agent_state.memory.get_block(label).id + block = self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=actor) + + # refresh memory from DB (using block ids) + agent_state.memory = Memory( + blocks=[self.block_manager.get_block_by_id(block.id, actor=actor) for block in agent_state.memory.get_blocks()] + ) + + # NOTE: don't do this since re-buildin the memory is handled at the start of the step + # rebuild memory - this records the last edited timestamp of the memory + # TODO: pass in update timestamp from block edit time + agent_state = self.rebuild_system_prompt(agent_id=agent_id, actor=actor) + + return agent_state + # ====================================================================================================================== # Source Management # ====================================================================================================================== diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 7ae743a7..fe10671d 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -107,12 +107,14 @@ class BlockManager: @enforce_types def add_default_blocks(self, actor: PydanticUser): for persona_file in list_persona_files(): - text = open(persona_file, "r", encoding="utf-8").read() + with open(persona_file, "r", encoding="utf-8") as f: + text = f.read() name = os.path.basename(persona_file).replace(".txt", "") self.create_or_update_block(Persona(template_name=name, value=text, is_template=True), actor=actor) for human_file in list_human_files(): - text = open(human_file, "r", encoding="utf-8").read() + with open(human_file, "r", encoding="utf-8") as f: + text = f.read() name = os.path.basename(human_file).replace(".txt", "") self.create_or_update_block(Human(template_name=name, value=text, is_template=True), actor=actor) diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index ed2881b3..26f0bee5 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -2,6 +2,7 @@ from typing import List, Optional from sqlalchemy import and_, or_ +from letta.log import get_logger from letta.orm.agent import Agent as AgentModel from letta.orm.errors import NoResultFound from letta.orm.message import Message as MessageModel @@ -11,6 +12,8 @@ from letta.schemas.message import MessageUpdate from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types +logger = get_logger(__name__) + class MessageManager: """Manager class to handle business logic related to Messages.""" @@ -37,7 +40,7 @@ class MessageManager: results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids)) if len(results) != len(message_ids): - raise NoResultFound( + logger.warning( f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}" ) diff --git a/tests/integration_test_chat_completions.py b/tests/integration_test_chat_completions.py index 4ab3b1d8..4b501de3 100644 --- a/tests/integration_test_chat_completions.py +++ b/tests/integration_test_chat_completions.py @@ -14,7 +14,9 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageStreamStatus from letta.schemas.llm_config import LLMConfig from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage +from letta.schemas.tool import ToolCreate from letta.schemas.usage import LettaUsageStatistics +from letta.services.tool_manager import ToolManager # --- Server Management --- # @@ -69,9 +71,49 @@ def roll_dice_tool(client): @pytest.fixture(scope="function") -def agent(client, roll_dice_tool): +def weather_tool(client): + def get_weather(location: str) -> str: + """ + Fetches the current weather for a given location. + + Parameters: + location (str): The location to get the weather for. + + Returns: + str: A formatted string describing the weather in the given location. + + Raises: + RuntimeError: If the request to fetch weather data fails. + """ + import requests + + url = f"https://wttr.in/{location}?format=%C+%t" + + response = requests.get(url) + if response.status_code == 200: + weather_data = response.text + return f"The weather in {location} is {weather_data}." + else: + raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}") + + tool = client.create_or_update_tool(func=get_weather) + # Yield the created tool + yield tool + + +@pytest.fixture(scope="function") +def composio_gmail_get_profile_tool(default_user): + tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE") + tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user) + yield tool + + +@pytest.fixture(scope="function") +def agent(client, roll_dice_tool, weather_tool, composio_gmail_get_profile_tool): """Creates an agent and ensures cleanup after tests.""" - agent_state = client.create_agent(name=f"test_client_{uuid.uuid4()}", tool_ids=[roll_dice_tool.id]) + agent_state = client.create_agent( + name=f"test_compl_{str(uuid.uuid4())[5:]}", tool_ids=[roll_dice_tool.id, weather_tool.id, composio_gmail_get_profile_tool.id] + ) yield agent_state client.delete_agent(agent_state.id) @@ -111,6 +153,19 @@ def _assert_valid_chunk(chunk, idx, chunks): # --- Test Cases --- # +@pytest.mark.parametrize("message", ["What's the weather in SF?"]) +@pytest.mark.parametrize("endpoint", ["fast/chat/completions"]) +def test_tool_usage_fast_chat_completions(mock_e2b_api_key_none, client, agent, message, endpoint): + """Tests chat completion streaming via SSE.""" + request = _get_chat_request(agent.id, message) + + response = _sse_post(f"{client.base_url}/openai/{client.api_prefix}/{endpoint}", request.model_dump(exclude_none=True), client.headers) + + for chunk in response: + if isinstance(chunk, ChatCompletionChunk) and chunk.choices: + print(chunk.choices[0].delta.content) + + @pytest.mark.parametrize("message", ["Tell me something interesting about bananas."]) @pytest.mark.parametrize("endpoint", ["chat/completions", "fast/chat/completions"]) def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message, endpoint):