fix: Update message schema / data type to match OAI tools style (#783)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Charles Packer
2024-01-04 15:05:43 -08:00
committed by GitHub
parent 1621774536
commit 97e8961528
4 changed files with 74 additions and 15 deletions

View File

@@ -411,12 +411,14 @@ class Agent(object):
response_message["tool_call_id"] = tool_call_id
# role: assistant (requesting tool call, set tool call ID)
messages.append(response_message) # extend conversation with assistant's reply
printd(f"Function call message: {messages[-1]}")
# Step 3: call the function
# Note: the JSON response may not always be valid; be sure to handle errors
# Failure case 1: function name is wrong
function_name = response_message["function_call"]["name"]
printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
try:
function_to_call = self.functions_python[function_name]
except KeyError as e:

View File

@@ -26,7 +26,7 @@ from memgpt.connectors.storage import StorageConnector, TableType
from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.constants import MEMGPT_DIR
from memgpt.utils import printd
from memgpt.data_types import Record, Message, Passage, Source
from memgpt.data_types import Record, Message, Passage, Source, ToolCall
from datetime import datetime
@@ -71,6 +71,26 @@ class CommonVector(TypeDecorator):
return np.array(list_value)
class ToolCalls(TypeDecorator):
    """SQLAlchemy column type that persists a List[ToolCall] as a JSON array.

    Each ToolCall object is flattened to a plain dict on the way into the
    database and rebuilt into a ToolCall instance on the way back out.
    """

    impl = JSON

    def load_dialect_impl(self, dialect):
        # Always back this type with the dialect's native JSON representation.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Outbound: flatten each ToolCall into its attribute dict so the
        # JSON encoder can serialize it. Empty/None values pass through.
        if not value:
            return value
        return [vars(tool_call) for tool_call in value]

    def process_result_value(self, value, dialect):
        # Inbound: rebuild ToolCall objects from the stored dicts.
        # Empty/None values pass through unchanged.
        if not value:
            return value
        return [ToolCall(**tool_call_dict) for tool_call_dict in value]
Base = declarative_base()
@@ -155,8 +175,7 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"):
# if role == "assistant", this MAY be specified
# if role != "assistant", this must be null
# TODO align with OpenAI spec of multiple tool calls
tool_name = Column(String)
tool_args = Column(String)
tool_calls = Column(ToolCalls)
# tool call response info
# if role == "tool", then this must be specified
@@ -185,8 +204,7 @@ def get_db_model(table_name: str, table_type: TableType, dialect="postgresql"):
user=self.user,
text=self.text,
model=self.model,
tool_name=self.tool_name,
tool_args=self.tool_args,
tool_calls=self.tool_calls,
tool_call_id=self.tool_call_id,
embedding=self.embedding,
created_at=self.created_at,

View File

@@ -1,7 +1,7 @@
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """
import uuid
from abc import abstractmethod
from typing import Optional
from typing import Optional, List, Dict
import numpy as np
@@ -24,8 +24,28 @@ class Record:
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
class ToolCall(object):
    """A single tool invocation requested by the model.

    Mirrors the OpenAI tool-call schema: a call id, a call type
    (currently only 'function' is defined by the OAI schema), and a
    function payload shaped like {'name': ..., 'arguments': ...}.
    """

    def __init__(
        self,
        id: str,
        # TODO should we include this? it's fixed to 'function' only (for now) in OAI schema
        tool_call_type: str,
        function: Dict[str, str],
    ):
        # Unique identifier for this tool call (echoed back by the tool result).
        self.id = id
        # Only 'function' is supported at present.
        self.tool_call_type = tool_call_type
        # function payload: {'name': ..., 'arguments': ...}
        self.function = function
class Message(Record):
"""Representation of a message sent from the agent -> user. Also includes function calls."""
"""Representation of a message sent.
Messages can be:
- agent->user (role=='agent')
- user->agent and system->agent (role=='user')
- or function/tool call returns (role=='function'/'tool').
"""
def __init__(
self,
@@ -36,9 +56,8 @@ class Message(Record):
model: str, # model used to make function call
user: Optional[str] = None, # optional participant name
created_at: Optional[str] = None,
tool_name: Optional[str] = None, # name of tool used
tool_args: Optional[str] = None, # args of tool used
tool_call_id: Optional[str] = None, # id of tool call
tool_calls: Optional[List[ToolCall]] = None, # list of tool calls requested
tool_call_id: Optional[str] = None,
embedding: Optional[np.ndarray] = None,
id: Optional[str] = None,
):
@@ -54,8 +73,13 @@ class Message(Record):
self.user = user
# tool (i.e. function) call info (optional)
self.tool_name = tool_name
self.tool_args = tool_args
# if role == "assistant", this MAY be specified
# if role != "assistant", this must be null
self.tool_calls = tool_calls
# if role == "tool", then this must be specified
# if role != "tool", this must be null
self.tool_call_id = tool_call_id
# embedding (optional)

View File

@@ -7,7 +7,7 @@ from memgpt.memory import (
EmbeddingArchivalMemory,
)
from memgpt.utils import get_local_time, printd
from memgpt.data_types import Message
from memgpt.data_types import Message, ToolCall
from memgpt.config import MemGPTConfig
from datetime import datetime
@@ -116,6 +116,22 @@ class LocalStateManager(PersistenceManager):
timestamp = message_json["timestamp"]
message = message_json["message"]
# TODO: change this when we fully migrate to tool calls API
if "function_call" in message:
tool_calls = [
ToolCall(
id=message["tool_call_id"],
tool_call_type="function",
function={
"name": message["function_call"]["name"],
"arguments": message["function_call"]["arguments"],
},
)
]
printd(f"Saving tool calls {[vars(tc) for tc in tool_calls]}")
else:
tool_calls = None
return Message(
user_id=self.config.anon_clientid,
agent_id=self.agent_config.name,
@@ -123,8 +139,7 @@ class LocalStateManager(PersistenceManager):
text=message["content"],
model=self.agent_config.model,
created_at=parse_formatted_time(timestamp),
tool_name=message["function_name"] if "function_name" in message else None,
tool_args=message["function_args"] if "function_args" in message else None,
tool_calls=tool_calls,
tool_call_id=message["tool_call_id"] if "tool_call_id" in message else None,
id=message["id"] if "id" in message else None,
)