fix: Update message schema / data type to match OAI tools style (#783)
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
@@ -411,12 +411,14 @@ class Agent(object):
|
||||
response_message["tool_call_id"] = tool_call_id
|
||||
# role: assistant (requesting tool call, set tool call ID)
|
||||
messages.append(response_message) # extend conversation with assistant's reply
|
||||
printd(f"Function call message: {messages[-1]}")
|
||||
|
||||
# Step 3: call the function
|
||||
# Note: the JSON response may not always be valid; be sure to handle errors
|
||||
|
||||
# Failure case 1: function name is wrong
|
||||
function_name = response_message["function_call"]["name"]
|
||||
printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
|
||||
try:
|
||||
function_to_call = self.functions_python[function_name]
|
||||
except KeyError as e:
|
||||
|
||||
@@ -26,7 +26,7 @@ from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.utils import printd
|
||||
from memgpt.data_types import Record, Message, Passage, Source
|
||||
from memgpt.data_types import Record, Message, Passage, Source, ToolCall
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
@@ -71,6 +71,26 @@ class CommonVector(TypeDecorator):
|
||||
return np.array(list_value)
|
||||
|
||||
|
||||
class ToolCalls(TypeDecorator):
|
||||
|
||||
"""Custom type for storing List[ToolCall] as JSON"""
|
||||
|
||||
impl = JSON
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
return dialect.type_descriptor(JSON())
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value:
|
||||
return [vars(v) for v in value]
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value:
|
||||
return [ToolCall(**v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
@@ -155,8 +175,7 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"):
|
||||
# if role == "assistant", this MAY be specified
|
||||
# if role != "assistant", this must be null
|
||||
# TODO align with OpenAI spec of multiple tool calls
|
||||
tool_name = Column(String)
|
||||
tool_args = Column(String)
|
||||
tool_calls = Column(ToolCalls)
|
||||
|
||||
# tool call response info
|
||||
# if role == "tool", then this must be specified
|
||||
@@ -185,8 +204,7 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"):
|
||||
user=self.user,
|
||||
text=self.text,
|
||||
model=self.model,
|
||||
tool_name=self.tool_name,
|
||||
tool_args=self.tool_args,
|
||||
tool_calls=self.tool_calls,
|
||||
tool_call_id=self.tool_call_id,
|
||||
embedding=self.embedding,
|
||||
created_at=self.created_at,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Dict
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -24,8 +24,28 @@ class Record:
|
||||
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
|
||||
|
||||
|
||||
class ToolCall(object):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
# TODO should we include this? it's fixed to 'function' only (for now) in OAI schema
|
||||
tool_call_type: str, # only 'function' is supported
|
||||
# function: { 'name': ..., 'arguments': ...}
|
||||
function: Dict[str, str],
|
||||
):
|
||||
self.id = id
|
||||
self.tool_call_type = tool_call_type
|
||||
self.function = function
|
||||
|
||||
|
||||
class Message(Record):
|
||||
"""Representation of a message sent from the agent -> user. Also includes function calls."""
|
||||
"""Representation of a message sent.
|
||||
|
||||
Messages can be:
|
||||
- agent->user (role=='agent')
|
||||
- user->agent and system->agent (role=='user')
|
||||
- or function/tool call returns (role=='function'/'tool').
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -36,9 +56,8 @@ class Message(Record):
|
||||
model: str, # model used to make function call
|
||||
user: Optional[str] = None, # optional participant name
|
||||
created_at: Optional[str] = None,
|
||||
tool_name: Optional[str] = None, # name of tool used
|
||||
tool_args: Optional[str] = None, # args of tool used
|
||||
tool_call_id: Optional[str] = None, # id of tool call
|
||||
tool_calls: Optional[List[ToolCall]] = None, # list of tool calls requested
|
||||
tool_call_id: Optional[str] = None,
|
||||
embedding: Optional[np.ndarray] = None,
|
||||
id: Optional[str] = None,
|
||||
):
|
||||
@@ -54,8 +73,13 @@ class Message(Record):
|
||||
self.user = user
|
||||
|
||||
# tool (i.e. function) call info (optional)
|
||||
self.tool_name = tool_name
|
||||
self.tool_args = tool_args
|
||||
|
||||
# if role == "assistant", this MAY be specified
|
||||
# if role != "assistant", this must be null
|
||||
self.tool_calls = tool_calls
|
||||
|
||||
# if role == "tool", then this must be specified
|
||||
# if role != "tool", this must be null
|
||||
self.tool_call_id = tool_call_id
|
||||
|
||||
# embedding (optional)
|
||||
|
||||
@@ -7,7 +7,7 @@ from memgpt.memory import (
|
||||
EmbeddingArchivalMemory,
|
||||
)
|
||||
from memgpt.utils import get_local_time, printd
|
||||
from memgpt.data_types import Message
|
||||
from memgpt.data_types import Message, ToolCall
|
||||
from memgpt.config import MemGPTConfig
|
||||
|
||||
from datetime import datetime
|
||||
@@ -116,6 +116,22 @@ class LocalStateManager(PersistenceManager):
|
||||
timestamp = message_json["timestamp"]
|
||||
message = message_json["message"]
|
||||
|
||||
# TODO: change this when we fully migrate to tool calls API
|
||||
if "function_call" in message:
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=message["tool_call_id"],
|
||||
tool_call_type="function",
|
||||
function={
|
||||
"name": message["function_call"]["name"],
|
||||
"arguments": message["function_call"]["arguments"],
|
||||
},
|
||||
)
|
||||
]
|
||||
printd(f"Saving tool calls {[vars(tc) for tc in tool_calls]}")
|
||||
else:
|
||||
tool_calls = None
|
||||
|
||||
return Message(
|
||||
user_id=self.config.anon_clientid,
|
||||
agent_id=self.agent_config.name,
|
||||
@@ -123,8 +139,7 @@ class LocalStateManager(PersistenceManager):
|
||||
text=message["content"],
|
||||
model=self.agent_config.model,
|
||||
created_at=parse_formatted_time(timestamp),
|
||||
tool_name=message["function_name"] if "function_name" in message else None,
|
||||
tool_args=message["function_args"] if "function_args" in message else None,
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=message["tool_call_id"] if "tool_call_id" in message else None,
|
||||
id=message["id"] if "id" in message else None,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user