diff --git a/memgpt/agent.py b/memgpt/agent.py index 33dff5b4..607d15fb 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -411,12 +411,14 @@ class Agent(object): response_message["tool_call_id"] = tool_call_id # role: assistant (requesting tool call, set tool call ID) messages.append(response_message) # extend conversation with assistant's reply + printd(f"Function call message: {messages[-1]}") # Step 3: call the function # Note: the JSON response may not always be valid; be sure to handle errors # Failure case 1: function name is wrong function_name = response_message["function_call"]["name"] + printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}") try: function_to_call = self.functions_python[function_name] except KeyError as e: diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index 2538eb9e..5e32f066 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/connectors/db.py @@ -26,7 +26,7 @@ from memgpt.connectors.storage import StorageConnector, TableType from memgpt.config import AgentConfig, MemGPTConfig from memgpt.constants import MEMGPT_DIR from memgpt.utils import printd -from memgpt.data_types import Record, Message, Passage, Source +from memgpt.data_types import Record, Message, Passage, Source, ToolCall from datetime import datetime @@ -71,6 +71,26 @@ class CommonVector(TypeDecorator): return np.array(list_value) +class ToolCalls(TypeDecorator): + + """Custom type for storing List[ToolCall] as JSON""" + + impl = JSON + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + if value: + return [vars(v) for v in value] + return value + + def process_result_value(self, value, dialect): + if value: + return [ToolCall(**v) for v in value] + return value + + Base = declarative_base() @@ -155,8 +175,7 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"): # if role == "assistant", this MAY be specified # if role != "assistant", this must be null # TODO align with OpenAI spec of multiple tool calls - tool_name = Column(String) - tool_args = Column(String) + tool_calls = Column(ToolCalls) # tool call response info # if role == "tool", then this must be specified @@ -185,8 +204,7 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"): user=self.user, text=self.text, model=self.model, - tool_name=self.tool_name, - tool_args=self.tool_args, + tool_calls=self.tool_calls, tool_call_id=self.tool_call_id, embedding=self.embedding, created_at=self.created_at, diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 4e0453ad..d2768bd7 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -1,7 +1,7 @@ """ This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """ import uuid from abc import abstractmethod -from typing import Optional +from typing import Optional, List, Dict import numpy as np @@ -24,8 +24,28 @@ class Record: assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type" +class ToolCall(object): + def __init__( + self, + id: str, + # TODO should we include this? it's fixed to 'function' only (for now) in OAI schema + tool_call_type: str, # only 'function' is supported + # function: { 'name': ..., 'arguments': ...} + function: Dict[str, str], + ): + self.id = id + self.tool_call_type = tool_call_type + self.function = function + + class Message(Record): - """Representation of a message sent from the agent -> user. Also includes function calls.""" + """Representation of a message sent. + + Messages can be: + - agent->user (role=='agent') + - user->agent and system->agent (role=='user') + - or function/tool call returns (role=='function'/'tool'). + """ def __init__( self, @@ -36,9 +56,8 @@ class Message(Record): model: str, # model used to make function call user: Optional[str] = None, # optional participant name created_at: Optional[str] = None, - tool_name: Optional[str] = None, # name of tool used - tool_args: Optional[str] = None, # args of tool used - tool_call_id: Optional[str] = None, # id of tool call + tool_calls: Optional[List[ToolCall]] = None, # list of tool calls requested + tool_call_id: Optional[str] = None, embedding: Optional[np.ndarray] = None, id: Optional[str] = None, ): @@ -54,8 +73,13 @@ class Message(Record): self.user = user # tool (i.e. function) call info (optional) - self.tool_name = tool_name - self.tool_args = tool_args + + # if role == "assistant", this MAY be specified + # if role != "assistant", this must be null + self.tool_calls = tool_calls + + # if role == "tool", then this must be specified + # if role != "tool", this must be null self.tool_call_id = tool_call_id # embedding (optional) diff --git a/memgpt/persistence_manager.py b/memgpt/persistence_manager.py index b42f5efe..ad7621cc 100644 --- a/memgpt/persistence_manager.py +++ b/memgpt/persistence_manager.py @@ -7,7 +7,7 @@ from memgpt.memory import ( EmbeddingArchivalMemory, ) from memgpt.utils import get_local_time, printd -from memgpt.data_types import Message +from memgpt.data_types import Message, ToolCall from memgpt.config import MemGPTConfig from datetime import datetime @@ -116,6 +116,22 @@ class LocalStateManager(PersistenceManager): timestamp = message_json["timestamp"] message = message_json["message"] + # TODO: change this when we fully migrate to tool calls API + if "function_call" in message: + tool_calls = [ + ToolCall( + id=message["tool_call_id"], + tool_call_type="function", + function={ + "name": message["function_call"]["name"], + "arguments": message["function_call"]["arguments"], + }, + ) + ] + printd(f"Saving tool calls {[vars(tc) for tc in tool_calls]}") + else: + tool_calls = None + return Message( user_id=self.config.anon_clientid, agent_id=self.agent_config.name, @@ -123,8 +139,7 @@ class LocalStateManager(PersistenceManager): text=message["content"], model=self.agent_config.model, created_at=parse_formatted_time(timestamp), - tool_name=message["function_name"] if "function_name" in message else None, - tool_args=message["function_args"] if "function_args" in message else None, + tool_calls=tool_calls, tool_call_id=message["tool_call_id"] if "tool_call_id" in message else None, id=message["id"] if "id" in message else None, )