feat: add Google AI Gemini Pro support (#1209)

This commit is contained in:
Charles Packer
2024-04-10 19:43:44 -07:00
committed by GitHub
parent edcc143365
commit 703ea47d72
19 changed files with 1355 additions and 539 deletions

View File

@@ -15,7 +15,7 @@ from memgpt.interface import AgentInterface
from memgpt.persistence_manager import LocalStateManager
from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from memgpt.memory import CoreMemory as InContextMemory, summarize_messages, ArchivalMemory, RecallMemory
from memgpt.llm_api_tools import create, is_context_overflow_error
from memgpt.llm_api.llm_api_tools import create, is_context_overflow_error
from memgpt.utils import (
get_utc_time,
create_random_username,
@@ -400,7 +400,7 @@ class Agent(object):
def _get_ai_reply(
self,
message_sequence: List[dict],
message_sequence: List[Message],
function_call: str = "auto",
first_message: bool = False, # hint
) -> chat_completion_response.ChatCompletionResponse:
@@ -694,12 +694,12 @@ class Agent(object):
self.interface.user_message(user_message.text, msg_obj=user_message)
input_message_sequence = self.messages + [user_message.to_openai_dict()]
input_message_sequence = self._messages + [user_message]
# Alternatively, the requestor can send an empty user message
else:
input_message_sequence = self.messages
input_message_sequence = self._messages
if len(input_message_sequence) > 1 and input_message_sequence[-1]["role"] != "user":
if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
# Step 1: send the conversation and available functions to GPT
@@ -858,14 +858,14 @@ class Agent(object):
printd(f"Selected cutoff {cutoff} was a 'tool', shifting one...")
cutoff += 1
message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message
message_sequence_to_summarize = self._messages[1:cutoff] # do NOT get rid of the system message
if len(message_sequence_to_summarize) <= 1:
# This prevents a potential infinite loop of summarizing the same message over and over
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(message_sequence_to_summarize)} <= 1]"
)
else:
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self._messages)}")
# We can't do summarize logic properly if context_window is undefined
if self.agent_state.llm_config.context_window is None:

View File

@@ -18,7 +18,9 @@ from memgpt.constants import LLM_MAX_TOKENS
from memgpt.constants import MEMGPT_DIR
from memgpt.credentials import MemGPTCredentials, SUPPORTED_AUTH_TYPES
from memgpt.data_types import User, LLMConfig, EmbeddingConfig
from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.llm_api.openai import openai_get_model_list
from memgpt.llm_api.azure_openai import azure_openai_get_model_list
from memgpt.llm_api.google_ai import google_ai_get_model_list, google_ai_get_model_context_window
from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.server.utils import shorten_key_middle
@@ -45,11 +47,16 @@ def get_azure_credentials():
return creds
def get_openai_credentials():
openai_key = os.getenv("OPENAI_API_KEY")
def get_openai_credentials() -> Optional[str]:
openai_key = os.getenv("OPENAI_API_KEY", None)
return openai_key
def get_google_ai_credentials() -> Optional[str]:
    """Look up the Google AI API key in the GOOGLE_AI_API_KEY environment variable.

    Returns:
        The key string, or None when the variable is not set.
    """
    return os.environ.get("GOOGLE_AI_API_KEY")
def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials):
# configure model endpoint
model_endpoint_type, model_endpoint = None, None
@@ -59,11 +66,12 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
if config.default_llm_config.model_endpoint_type is not None and config.default_llm_config.model_endpoint_type not in [
"openai",
"azure",
"google_ai",
]: # local model
default_model_endpoint_type = "local"
provider = questionary.select(
"Select LLM inference provider:", choices=["openai", "azure", "local"], default=default_model_endpoint_type
"Select LLM inference provider:", choices=["openai", "azure", "google_ai", "local"], default=default_model_endpoint_type
).ask()
if provider is None:
raise KeyboardInterrupt
@@ -131,6 +139,51 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
model_endpoint_type = "azure"
model_endpoint = azure_creds["azure_endpoint"]
elif provider == "google_ai":
# check for key
if credentials.google_ai_key is None:
# allow key to get pulled from env vars
google_ai_key = get_google_ai_credentials()
# if we still can't find it, ask for it as input
if google_ai_key is None:
while google_ai_key is None or len(google_ai_key) == 0:
# Ask for API key as input
google_ai_key = questionary.password(
"Enter your Google AI (Gemini) API key (see https://aistudio.google.com/app/apikey):"
).ask()
if google_ai_key is None:
raise KeyboardInterrupt
credentials.google_ai_key = google_ai_key
else:
# Give the user an opportunity to overwrite the key
google_ai_key = None
default_input = shorten_key_middle(credentials.google_ai_key)
google_ai_key = questionary.password(
"Enter your Google AI (Gemini) API key (see https://aistudio.google.com/app/apikey):",
default=default_input,
).ask()
if google_ai_key is None:
raise KeyboardInterrupt
# If the user modified it, use the new one
if google_ai_key != default_input:
credentials.google_ai_key = google_ai_key
default_input = os.getenv("GOOGLE_AI_SERVICE_ENDPOINT", None)
if default_input is None:
default_input = "generativelanguage"
google_ai_service_endpoint = questionary.text(
"Enter your Google AI (Gemini) service endpoint (see https://ai.google.dev/api/rest):",
default=default_input,
).ask()
credentials.google_ai_service_endpoint = google_ai_service_endpoint
# write out the credentials
credentials.save()
model_endpoint_type = "google_ai"
else: # local models
# backend_options_old = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"]
backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
@@ -223,6 +276,21 @@ def get_model_options(
else:
model_options = [obj["id"] for obj in fetched_model_options_response["data"]]
elif model_endpoint_type == "google_ai":
if credentials.google_ai_key is None:
raise ValueError("Missing Google AI API key")
if credentials.google_ai_service_endpoint is None:
raise ValueError("Missing Google AI service endpoint")
model_options = google_ai_get_model_list(
service_endpoint=credentials.google_ai_service_endpoint, api_key=credentials.google_ai_key
)
model_options = [str(m["name"]) for m in model_options]
model_options = [mo[len("models/") :] if mo.startswith("models/") else mo for mo in model_options]
# TODO remove manual filtering for gemini-pro
model_options = [mo for mo in model_options if str(mo).startswith("gemini") and "-pro" in str(mo)]
# model_options = ["gemini-pro"]
else:
# Attempt to do OpenAI endpoint style model fetching
# TODO support local auth with api-key header
@@ -294,6 +362,26 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if model is None:
raise KeyboardInterrupt
elif model_endpoint_type == "google_ai":
try:
fetched_model_options = get_model_options(
credentials=credentials, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
)
except Exception as e:
# NOTE: if this fails, it means the user's key is probably bad
typer.secho(
f"Failed to get model list from {model_endpoint} - make sure your API key and endpoints are correct!", fg=typer.colors.RED
)
raise e
model = questionary.select(
"Select default model:",
choices=fetched_model_options,
default=fetched_model_options[0],
).ask()
if model is None:
raise KeyboardInterrupt
else: # local models
# ask about local auth
@@ -412,7 +500,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
# set: context_window
if str(model) not in LLM_MAX_TOKENS:
# Ask the user to specify the context length
context_length_options = [
str(2**12), # 4096
str(2**13), # 8192
@@ -421,13 +509,40 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
str(2**18), # 262144
"custom", # enter yourself
]
context_window_input = questionary.select(
"Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
choices=context_length_options,
default=str(LLM_MAX_TOKENS["DEFAULT"]),
).ask()
if context_window_input is None:
raise KeyboardInterrupt
if model_endpoint_type == "google_ai":
try:
fetched_context_window = str(
google_ai_get_model_context_window(
service_endpoint=credentials.google_ai_service_endpoint, api_key=credentials.google_ai_key, model=model
)
)
print(f"Got context window {fetched_context_window} for model {model} (from Google API)")
context_length_options = [
fetched_context_window,
"custom",
]
except:
print(f"Failed to get model details for model '{model}' on Google AI API")
context_window_input = questionary.select(
"Select your model's context window (see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions):",
choices=context_length_options,
default=context_length_options[0],
).ask()
if context_window_input is None:
raise KeyboardInterrupt
else:
# Ask the user to specify the context length
context_window_input = questionary.select(
"Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):",
choices=context_length_options,
default=str(LLM_MAX_TOKENS["DEFAULT"]),
).ask()
if context_window_input is None:
raise KeyboardInterrupt
# If custom, ask for input
if context_window_input == "custom":

View File

@@ -86,6 +86,8 @@ MESSAGE_SUMMARY_WARNING_STR = " ".join(
)
# The fraction of tokens we truncate down to
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC = 0.75
# The acknowledgement message used in the summarize sequence
MESSAGE_SUMMARY_REQUEST_ACK = "Understood, I will respond with a summary of the message (and only the summary, nothing else) once I receive the conversation history. I'm ready."
# Even when summarizing, we want to keep a handful of recent messages
# These serve as in-context examples of how to use functions / what user messages look like

View File

@@ -28,6 +28,10 @@ class MemGPTCredentials:
openai_auth_type: str = "bearer_token"
openai_key: Optional[str] = None
# gemini config
google_ai_key: Optional[str] = None
google_ai_service_endpoint: Optional[str] = None
# azure config
azure_auth_type: str = "api_key"
azure_key: Optional[str] = None
@@ -70,6 +74,9 @@ class MemGPTCredentials:
"azure_embedding_version": get_field(config, "azure", "embedding_version"),
"azure_embedding_endpoint": get_field(config, "azure", "embedding_endpoint"),
"azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"),
# gemini
"google_ai_key": get_field(config, "google_ai", "key"),
"google_ai_service_endpoint": get_field(config, "google_ai", "service_endpoint"),
# open llm
"openllm_auth_type": get_field(config, "openllm", "auth_type"),
"openllm_key": get_field(config, "openllm", "key"),
@@ -102,7 +109,11 @@ class MemGPTCredentials:
set_field(config, "azure", "embedding_endpoint", self.azure_embedding_endpoint)
set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment)
# openai config
# gemini
set_field(config, "google_ai", "key", self.google_ai_key)
set_field(config, "google_ai", "service_endpoint", self.google_ai_service_endpoint)
# openllm config
set_field(config, "openllm", "auth_type", self.openllm_auth_type)
set_field(config, "openllm", "key", self.openllm_key)

View File

@@ -1,6 +1,7 @@
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """
import uuid
import json
from datetime import datetime, timezone
from typing import Optional, List, Dict, TypeVar
import numpy as np
@@ -18,6 +19,7 @@ from memgpt.constants import (
from memgpt.utils import get_utc_time, create_uuid_from_string
from memgpt.models import chat_completion_response
from memgpt.utils import get_human_text, get_persona_text, printd, is_utc_datetime
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
class Record:
@@ -82,6 +84,7 @@ class Message(Record):
created_at: Optional[datetime] = None,
tool_calls: Optional[List[ToolCall]] = None, # list of tool calls requested
tool_call_id: Optional[str] = None,
# tool_call_name: Optional[str] = None, # not technically OpenAI spec, but it can be helpful to have on-hand
embedding: Optional[np.ndarray] = None,
embedding_dim: Optional[int] = None,
embedding_model: Optional[str] = None,
@@ -238,7 +241,7 @@ class Message(Record):
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
)
def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN):
def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
"""Go from Message class to ChatCompletion message object"""
# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
@@ -285,11 +288,117 @@ class Message(Record):
"role": self.role,
"tool_call_id": self.tool_call_id[:max_tool_id_length] if max_tool_id_length else self.tool_call_id,
}
else:
raise ValueError(self.role)
return openai_message
def to_google_ai_dict(self, put_inner_thoughts_in_kwargs: bool = True) -> dict:
    """Go from Message class to Google AI REST message object.

    type Content: https://ai.google.dev/api/rest/v1/Content / https://ai.google.dev/api/rest/v1beta/Content
        parts[]: Part
        role: str ('user' or 'model')

    Args:
        put_inner_thoughts_in_kwargs: When True, the text of an assistant message
            that carries a tool call is attached to that call's arguments under
            INNER_THOUGHTS_KWARG (the API does not accept text + function call).
            When False, an assistant message with non-null text is not supported.

    Returns:
        A dict with "role" and "parts" keys.

    Raises:
        UserWarning: a non-tool message has a non-null `name`, or a tool call's
            arguments are not valid JSON.
        NotImplementedError: assistant text present while
            `put_inner_thoughts_in_kwargs` is False.
        ValueError: the message role is not system/user/assistant/tool.
    """
    import warnings  # local import: only needed for the tool-name fallback below

    if self.role != "tool" and self.name is not None:
        raise UserWarning(f"Using Google AI with non-null 'name' field ({self.name}) not yet supported.")

    if self.role == "system":
        # NOTE: Gemini API doesn't have a 'system' role, use 'user' instead
        # https://www.reddit.com/r/Bard/comments/1b90i8o/does_gemini_have_a_system_prompt_option_while/
        google_ai_message = {
            "role": "user",  # NOTE: no 'system'
            "parts": [{"text": self.text}],
        }

    elif self.role == "user":
        assert all([v is not None for v in [self.text, self.role]]), vars(self)
        google_ai_message = {
            "role": "user",
            "parts": [{"text": self.text}],
        }

    elif self.role == "assistant":
        assert self.tool_calls is not None or self.text is not None
        google_ai_message = {
            "role": "model",  # NOTE: different
        }

        # NOTE: Google AI API doesn't allow non-null content + function call.
        # To get around this, the inner thoughts travel inside the function call arguments.
        parts = []
        if not put_inner_thoughts_in_kwargs and self.text is not None:
            # NOTE: ideally we do multi-part for CoT / inner thoughts + function call, but Google AI API doesn't allow it
            raise NotImplementedError

        if self.tool_calls is not None:
            # NOTE: implied support for multiple calls
            for tool_call in self.tool_calls:
                function_name = tool_call.function["name"]
                function_args = tool_call.function["arguments"]
                try:
                    # NOTE: Google AI wants actual JSON objects, not strings
                    function_args = json.loads(function_args)
                except (TypeError, ValueError) as err:
                    raise UserWarning(f"Failed to parse JSON function args: {function_args}") from err

                if put_inner_thoughts_in_kwargs and self.text is not None:
                    assert INNER_THOUGHTS_KWARG not in function_args, function_args
                    # Inner thoughts can only be attached unambiguously to a single call
                    assert len(self.tool_calls) == 1
                    function_args[INNER_THOUGHTS_KWARG] = self.text

                parts.append(
                    {
                        "functionCall": {
                            "name": function_name,
                            "args": function_args,
                        }
                    }
                )
        else:
            assert self.text is not None
            parts.append({"text": self.text})
        google_ai_message["parts"] = parts

    elif self.role == "tool":
        # NOTE: Significantly different tool calling format, more similar to function calling format
        assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)

        if self.name is None:
            # Warn (rather than abort) and fall back to the tool call ID, as the message states
            warnings.warn(f"Couldn't find function name on tool call, defaulting to tool ID instead.", UserWarning)
            function_name = self.tool_call_id
        else:
            function_name = self.name

        # NOTE: Google AI API wants the function response as JSON only, no string
        try:
            function_response = json.loads(self.text)
        except (TypeError, ValueError):
            # Not JSON (or no text at all): wrap the raw value so the payload is still an object
            function_response = {"function_response": self.text}

        google_ai_message = {
            "role": "function",
            "parts": [
                {
                    "functionResponse": {
                        "name": function_name,
                        "response": {
                            "name": function_name,  # NOTE: name twice... why?
                            "content": function_response,
                        },
                    }
                }
            ],
        }

    else:
        raise ValueError(self.role)

    return google_ai_message
class Document(Record):
"""A document represent a document loaded into MemGPT, which is broken down into passages."""

View File

@@ -2,7 +2,7 @@ from typing import Optional
import os
import json
import requests
import uuid
from memgpt.constants import (
JSON_LOADS_STRICT,
@@ -11,7 +11,8 @@ from memgpt.constants import (
MAX_PAUSE_HEARTBEATS,
JSON_ENSURE_ASCII,
)
from memgpt.llm_api_tools import create
from memgpt.llm_api.llm_api_tools import create
from memgpt.data_types import Message
def message_chatgpt(self, message: str):
@@ -24,15 +25,15 @@ def message_chatgpt(self, message: str):
Returns:
str: Reply message from ChatGPT
"""
dummy_user_id = uuid.uuid4()
dummy_agent_id = uuid.uuid4()
message_sequence = [
{"role": "system", "content": MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE},
{"role": "user", "content": str(message)},
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE),
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=str(message)),
]
response = create(
model=MESSAGE_CHATGPT_FUNCTION_MODEL,
messages=message_sequence,
# functions=functions,
# function_call=function_call,
)
reply = response.choices[0].message.content

View File

View File

@@ -0,0 +1,154 @@
import requests
from typing import Union
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.embedding_response import EmbeddingResponse
from memgpt.utils import smart_urljoin
# Lookup table keyed by OpenAI-style model name, giving the corresponding Azure engine name
# NOTE(review): nothing in the functions shown in this module reads this table - confirm where it is consumed
MODEL_TO_AZURE_ENGINE = {
"gpt-4-1106-preview": "gpt-4",
"gpt-4": "gpt-4",
"gpt-4-32k": "gpt-4-32k",
"gpt-3.5": "gpt-35-turbo",
"gpt-3.5-turbo": "gpt-35-turbo",
"gpt-3.5-turbo-16k": "gpt-35-turbo-16k",
}
def clean_azure_endpoint(raw_endpoint_name: str) -> str:
    """Reduce an Azure OpenAI endpoint to its bare resource name.

    Accepts either the resource name itself or a URL such as
    'https://YOUR_RESOURCE_NAME.openai.azure.com/' and returns
    'YOUR_RESOURCE_NAME'; callers rebuild the full URL from it.

    Args:
        raw_endpoint_name: Resource name or endpoint URL as entered by the user.

    Returns:
        The resource name with scheme, Azure domain suffix, surrounding
        whitespace and any trailing URL path removed.

    Raises:
        ValueError: if `raw_endpoint_name` is None.
    """
    if raw_endpoint_name is None:
        raise ValueError("Azure endpoint name must not be None")

    # User-entered values often carry stray whitespace / newlines
    endpoint_address = raw_endpoint_name.strip()
    endpoint_address = endpoint_address.replace("http://", "").replace("https://", "")
    # Keep only the host portion: drop edge slashes and anything after the first path separator
    endpoint_address = endpoint_address.strip("/").split("/", 1)[0]
    return endpoint_address.replace(".openai.azure.com", "")
def _azure_response_for_log(response):
    """Return the decoded JSON body of `response` when possible, else `response` itself (may be None)."""
    try:
        return response.json()
    except Exception:
        # Best-effort for logging only: no response was received, or its body is not JSON
        return response


def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version: str) -> dict:
    """List the models available on an Azure OpenAI resource.

    https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP

    Args:
        url: Base URL of the Azure OpenAI resource.
        api_key: Sent as the 'api-key' header when not None.
        api_version: Value of the 'api-version' query parameter.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.exceptions.HTTPError: on a 4XX/5XX status.
        requests.exceptions.RequestException: on other request failures (e.g. connection errors).
        Exception: anything else (e.g. a non-JSON body) is logged and re-raised.
    """
    from memgpt.utils import printd

    # https://xxx.openai.azure.com/openai/models?api-version=xxx
    url = smart_urljoin(url, "openai")
    url = smart_urljoin(url, f"models?api-version={api_version}")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["api-key"] = f"{api_key}"

    printd(f"Sending request to {url}")
    # Pre-bind so the handlers can log safely even when the request itself never returned
    response = None
    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response_json = response.json()  # convert to dict from string
        printd(f"response = {response_json}")
        return response_json
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}, response={_azure_response_for_log(response)}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}, response={_azure_response_for_log(response)}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}, response={_azure_response_for_log(response)}")
        raise e
def azure_openai_chat_completions_request(
    resource_name: str, deployment_id: str, api_version: str, api_key: str, data: dict
) -> ChatCompletionResponse:
    """Send a chat completions request to an Azure OpenAI deployment.

    https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions

    Args:
        resource_name: Azure resource name or endpoint URL (normalized via `clean_azure_endpoint`).
        deployment_id: Name of the model deployment.
        api_version: Value of the 'api-version' query parameter.
        api_key: Sent as the 'api-key' header.
        data: Request payload. It is not modified; None-valued 'functions' / 'tools'
            entries are stripped from a copy before sending.

    Returns:
        The response parsed into a ChatCompletionResponse.

    Raises:
        AssertionError: if a required argument is None.
        requests.exceptions.HTTPError: on a 4XX/5XX status.
        requests.exceptions.RequestException: on other request failures.
    """
    from memgpt.utils import printd

    assert resource_name is not None, "Missing required field when calling Azure OpenAI"
    assert deployment_id is not None, "Missing required field when calling Azure OpenAI"
    assert api_version is not None, "Missing required field when calling Azure OpenAI"
    assert api_key is not None, "Missing required field when calling Azure OpenAI"

    resource_name = clean_azure_endpoint(resource_name)
    url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}"
    headers = {"Content-Type": "application/json", "api-key": f"{api_key}"}

    # Work on a shallow copy so the caller's payload dict is left untouched
    payload = dict(data)

    # If functions == None, strip from the payload
    if "functions" in payload and payload["functions"] is None:
        payload.pop("functions")
        payload.pop("function_call", None)  # extra safe, should exist always (default="auto")

    if "tools" in payload and payload["tools"] is None:
        payload.pop("tools")
        payload.pop("tool_choice", None)  # extra safe, should exist always (default="auto")

    printd(f"Sending request to {url}")
    try:
        response = requests.post(url, headers=headers, json=payload)
        printd(f"response = {response}")
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response_json = response.json()  # convert to dict from string
        printd(f"response.json = {response_json}")
        # NOTE: azure openai does not include "content" in the response when it is None, so we need to add it
        # (applied to every choice, not just the first, so multi-choice responses also parse)
        for choice in response_json["choices"]:
            message = choice.get("message")
            if message is not None and "content" not in message:
                message["content"] = None
        # convert to 'dot-dict' style which is the openai python client default
        return ChatCompletionResponse(**response_json)
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}, payload={payload}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e
def azure_openai_embeddings_request(
    resource_name: str, deployment_id: str, api_version: str, api_key: str, data: dict
) -> EmbeddingResponse:
    """POST an embeddings request to an Azure OpenAI deployment and parse the reply.

    Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings

    The resource name is normalized with `clean_azure_endpoint`, `data` is sent
    as the JSON body, and the decoded response is wrapped in EmbeddingResponse.
    Any exception is logged through `printd` and then re-raised unchanged.
    """
    from memgpt.utils import printd

    cleaned_resource = clean_azure_endpoint(resource_name)
    url = f"https://{cleaned_resource}.openai.azure.com/openai/deployments/{deployment_id}/embeddings?api-version={api_version}"
    request_headers = {"Content-Type": "application/json", "api-key": f"{api_key}"}

    printd(f"Sending request to {url}")
    try:
        http_response = requests.post(url, headers=request_headers, json=data)
        printd(f"response = {http_response}")
        http_response.raise_for_status()  # 4XX/5XX -> HTTPError
        body = http_response.json()
        printd(f"response.json = {body}")
        # 'dot-dict' style object, matching the openai python client default
        return EmbeddingResponse(**body)
    except Exception as err:
        # Log according to the most specific category, then propagate the same exception
        if isinstance(err, requests.exceptions.HTTPError):
            printd(f"Got HTTPError, exception={err}, payload={data}")
        elif isinstance(err, requests.exceptions.RequestException):
            printd(f"Got RequestException, exception={err}")
        else:
            printd(f"Got unknown Exception, exception={err}")
        raise

463
memgpt/llm_api/google_ai.py Normal file
View File

@@ -0,0 +1,463 @@
import requests
import json
import uuid
from typing import Union, List, Optional
from memgpt.models.chat_completion_response import ChatCompletionResponse, Choice, Message, ToolCall, FunctionCall, UsageStatistics
from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.models.embedding_response import EmbeddingResponse
from memgpt.utils import smart_urljoin, get_tool_call_id, get_utc_time
from memgpt.local_llm.utils import count_tokens
from memgpt.local_llm.json_parser import clean_json_string_extra_backslash
from memgpt.constants import NON_USER_MSG_PREFIX, JSON_ENSURE_ASCII
# from memgpt.data_types import ToolCall
# Model names this module lists as supported
# NOTE(review): not referenced by the functions visible in this chunk - confirm where it is enforced
SUPPORTED_MODELS = [
"gemini-pro",
]
def google_ai_get_model_details(service_endpoint: str, api_key: str, model: str, key_in_header: bool = True) -> dict:
    """Fetch the metadata of a single model from the Google AI REST API.

    Args:
        service_endpoint: Service name placed before '.googleapis.com' in the URL.
        api_key: Google AI API key.
        model: Model name appended to '/v1beta/models/'.
        key_in_header: Send the key in the 'x-goog-api-key' header (True) or as a
            '?key=' query parameter (False).

    Returns:
        The decoded JSON object describing the model (a dict, not a list).

    Raises:
        requests.exceptions.HTTPError: on a 4XX/5XX status (status and body are printed first).
        requests.exceptions.RequestException: on other request failures.
    """
    from memgpt.utils import printd

    # Two ways to pass the key: https://ai.google.dev/tutorials/setup
    if key_in_header:
        url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}"
        headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
    else:
        url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}?key={api_key}"
        headers = {"Content-Type": "application/json"}

    try:
        response = requests.get(url, headers=headers)
        printd(f"response = {response}")
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response.json = {response}")
        return response

    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise http_err

    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err

    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e
def google_ai_get_model_context_window(service_endpoint: str, api_key: str, model: str, key_in_header: bool = True) -> int:
    """Return the 'inputTokenLimit' reported for `model` by the Google AI API, as an int.

    All arguments are forwarded unchanged to `google_ai_get_model_details`.
    """
    details = google_ai_get_model_details(
        service_endpoint=service_endpoint,
        api_key=api_key,
        model=model,
        key_in_header=key_in_header,
    )
    # TODO should this be:
    # return model_details["inputTokenLimit"] + model_details["outputTokenLimit"]
    input_token_limit = details["inputTokenLimit"]
    return int(input_token_limit)
def google_ai_get_model_list(service_endpoint: str, api_key: str, key_in_header: bool = True) -> List[dict]:
    """Fetch the list of models exposed by the Google AI REST API.

    The API key is sent either in the 'x-goog-api-key' header (`key_in_header`
    True) or as a '?key=' query parameter. Returns the value stored under the
    "models" key of the JSON response. Failures are logged and re-raised.
    """
    from memgpt.utils import printd

    # Two ways to pass the key: https://ai.google.dev/tutorials/setup
    models_url = f"https://{service_endpoint}.googleapis.com/v1beta/models"
    if key_in_header:
        url = models_url
        headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}
    else:
        url = f"{models_url}?key={api_key}"
        headers = {"Content-Type": "application/json"}

    try:
        raw_response = requests.get(url, headers=headers)
        printd(f"response = {raw_response}")
        raw_response.raise_for_status()  # 4XX/5XX -> HTTPError
        decoded = raw_response.json()
        printd(f"response.json = {decoded}")
        # Grab the models out
        return decoded["models"]

    except requests.exceptions.HTTPError as http_err:
        # Server answered with an error status: surface status code and body to the user
        printd(f"Got HTTPError, exception={http_err}")
        print(f"HTTP Error: {http_err.response.status_code}")
        print(f"Message: {http_err.response.text}")
        raise http_err

    except requests.exceptions.RequestException as req_err:
        # Other requests-level failure (e.g. connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err

    except Exception as e:
        # Anything else, e.g. a response without a "models" key
        printd(f"Got unknown Exception, exception={e}")
        raise e
def add_dummy_model_messages(messages: List[dict]) -> List[dict]:
    """Pad function outputs that are directly followed by a user message.

    Google AI API requires all function call returns are immediately followed by a 'model' role message.

    In MemGPT, the 'model' will often call a function (e.g. send_message) that itself yields to the user,
    so there is no natural follow-up 'model' role message.

    To satisfy the Google AI API restrictions, we add a dummy 'yield' message
    with role == 'model' in between a function output (role == 'tool' or 'function')
    and the user message (role == 'user') that follows it.

    Args:
        messages: Message dicts, each with a "role" key.

    Returns:
        A new list containing the original message dicts plus the inserted
        dummy messages; the input list itself is not modified.
    """
    dummy_text = f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."

    messages_with_padding = []
    for i, message in enumerate(messages):
        messages_with_padding.append(message)
        # Check if the current message role is 'tool' and the next message role is 'user'
        if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"):
            # Build a fresh dict per insertion so later edits to one padded entry cannot affect another
            messages_with_padding.append({"role": "model", "parts": [{"text": dummy_text}]})

    return messages_with_padding
# TODO use pydantic model as input
def to_google_ai(openai_message_dict: dict) -> dict:
    """Convert an OpenAI-style message dict into a Google AI message dict.

    Role mapping: 'user' -> 'user', 'assistant' -> 'model', 'tool' -> 'function'.
    The message content becomes a single text part.

    Args:
        openai_message_dict: Dict with "role" and (non-list) "content" keys.

    Returns:
        {"role": <google role>, "parts": [{"text": <content>}]}

    Raises:
        AssertionError: if "content" is a list (multi-part content is unsupported).
        ValueError: if the role has no Google AI equivalent here.
    """
    # TODO supports "parts" as part of multimodal support
    assert not isinstance(openai_message_dict["content"], list), "Multi-part content is message not yet supported"

    google_roles = {
        "user": "user",
        "assistant": "model",  # NOTE: diff
        "tool": "function",  # NOTE: diff
    }
    role = openai_message_dict["role"]
    if role not in google_roles:
        raise ValueError(f"Unsupported conversion (OpenAI -> Google AI) from role {openai_message_dict['role']}")

    # The converted message must be returned, otherwise callers receive None
    return {
        "role": google_roles[role],
        "parts": [{"text": openai_message_dict["content"]}],
    }
# TODO convert return type to pydantic
def convert_tools_to_google_ai_format(tools: List[Tool], inner_thoughts_in_kwargs: Optional[bool] = True) -> List[dict]:
    """Convert OpenAI-style tool definitions into Google AI 'functionDeclarations'.

    OpenAI style:
      "tools": [{
        "type": "function",
        "function": {
            "name": "find_movies",
            "description": "find ....",
            "parameters": {
              "type": "object",
              "properties": {
                 PARAM: {
                   "type": PARAM_TYPE,  # eg "string"
                   "description": PARAM_DESCRIPTION,
                 },
                 ...
              },
              "required": List[str],
            }
        }
      }]

    Google AI style:
      "tools": [{
        "functionDeclarations": [{
          "name": "find_movies",
          "description": "find movie titles currently playing in theaters ...",
          "parameters": {
            "type": "OBJECT",
            "properties": {
              "location": {"type": "STRING", "description": "The city and state, ..."},
              "description": {"type": "STRING", "description": "Any kind of description ..."}
            },
            "required": ["description"]
          }
        }, {
          "name": "find_theaters",
          ...

    Args:
        tools: Tool objects exposing `function.name`, `function.description`
            and `function.parameters` (a JSON-schema dict).
        inner_thoughts_in_kwargs: When truthy, add a required string parameter
            named INNER_THOUGHTS_KWARG to every function.

    Returns:
        A one-element list: [{"functionDeclarations": [...]}]. The input tools'
        parameter dicts are left unmodified.
    """
    import copy

    function_list = [
        dict(
            name=t.function.name,
            description=t.function.description,
            # Deep copy: the casing / inner-thoughts edits below must not leak into the caller's
            # schema (otherwise repeated calls would keep appending to "required")
            parameters=copy.deepcopy(t.function.parameters),  # TODO need to unpack
        )
        for t in tools
    ]

    if inner_thoughts_in_kwargs:
        from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION

    # Correct casing + add inner thoughts if needed
    for func in function_list:
        func["parameters"]["type"] = "OBJECT"
        for param_fields in func["parameters"]["properties"].values():
            param_fields["type"] = param_fields["type"].upper()

        # Add inner thoughts
        if inner_thoughts_in_kwargs:
            func["parameters"]["properties"][INNER_THOUGHTS_KWARG] = {
                "type": "STRING",
                "description": INNER_THOUGHTS_KWARG_DESCRIPTION,
            }
            # "required" may be absent from a schema; create it rather than fail
            required = func["parameters"].setdefault("required", [])
            if INNER_THOUGHTS_KWARG not in required:
                required.append(INNER_THOUGHTS_KWARG)

    return [{"functionDeclarations": function_list}]
def convert_google_ai_response_to_chatcompletion(
    response_json: dict,  # REST response from Google AI API
    model: str,  # Required since not returned
    input_messages: Optional[List[dict]] = None,  # Required if the API doesn't return UsageMetadata
    pull_inner_thoughts_from_args: Optional[bool] = True,
) -> ChatCompletionResponse:
    """Google AI API response format is not the same as ChatCompletion, requires unpacking

    Example:
    {
      "candidates": [
        {
          "content": {
            "parts": [
              {
                "text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14."
              }
            ]
          }
        }
      ],
      "usageMetadata": {
        "promptTokenCount": 9,
        "candidatesTokenCount": 27,
        "totalTokenCount": 36
      }
    }

    Args:
        response_json: decoded REST response; must contain "candidates".
        model: model name copied into the result (the API does not echo it back).
        input_messages: request contents, used only to estimate prompt tokens when
            "usageMetadata" is missing from the response.
        pull_inner_thoughts_from_args: when truthy, the inner-thoughts kwarg is removed
            from the function-call arguments and used as the message content.

    Returns:
        A ChatCompletionResponse with one Choice per candidate.

    Raises:
        KeyError: an expected field is missing from the response.
        AssertionError: unexpected role, multi-part content, or malformed function call.
        ValueError: unrecognized finish reason.
        UserWarning: more than one candidate was returned.
    """
    choices = []
    # Kept separately from `choices` so token usage can be estimated even when there are no candidates
    response_messages = []

    for candidate in response_json["candidates"]:
        content = candidate["content"]

        role = content["role"]
        assert role == "model", f"Unknown role in response: {role}"

        parts = content["parts"]
        # TODO support parts / multimodal
        assert len(parts) == 1, f"Multi-part not yet supported:\n{parts}"
        response_message = parts[0]

        # Convert the actual message style to OpenAI style
        if "functionCall" in response_message and response_message["functionCall"] is not None:
            function_call = response_message["functionCall"]
            assert isinstance(function_call, dict), function_call
            function_name = function_call["name"]
            assert isinstance(function_name, str), function_name
            function_args = function_call["args"]
            assert isinstance(function_args, dict), function_args

            # NOTE: this also involves stripping the inner monologue out of the function
            if pull_inner_thoughts_from_args:
                from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG

                assert INNER_THOUGHTS_KWARG in function_args, f"Couldn't find inner thoughts in function args:\n{function_call}"
                inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG)
                assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
            else:
                inner_thoughts = None

            # Google AI API doesn't generate tool call IDs
            openai_response_message = Message(
                role="assistant",  # NOTE: "model" -> "assistant"
                content=inner_thoughts,
                tool_calls=[
                    ToolCall(
                        id=get_tool_call_id(),
                        type="function",
                        function=FunctionCall(
                            name=function_name,
                            arguments=clean_json_string_extra_backslash(json.dumps(function_args)),
                        ),
                    )
                ],
            )

        else:
            # Inner thoughts are the content by default
            inner_thoughts = response_message["text"]

            # Google AI API doesn't generate tool call IDs
            openai_response_message = Message(
                role="assistant",  # NOTE: "model" -> "assistant"
                content=inner_thoughts,
            )

        # Google AI API uses different finish reason strings than OpenAI
        # OpenAI: 'stop', 'length', 'function_call', 'content_filter', null
        #   see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api
        # Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
        #   see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
        finish_reason = candidate["finishReason"]
        if finish_reason == "STOP":
            openai_finish_reason = (
                "function_call"
                if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0
                else "stop"
            )
        elif finish_reason == "MAX_TOKENS":
            openai_finish_reason = "length"
        elif finish_reason == "SAFETY":
            openai_finish_reason = "content_filter"
        elif finish_reason == "RECITATION":
            openai_finish_reason = "content_filter"
        else:
            raise ValueError(f"Unrecognized finish reason in Google AI response: {finish_reason}")

        response_messages.append(openai_response_message)
        choices.append(
            Choice(
                finish_reason=openai_finish_reason,
                index=candidate["index"],
                message=openai_response_message,
            )
        )

    if len(choices) > 1:
        raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})")

    # NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist?
    #  "usageMetadata": {
    #     "promptTokenCount": 9,
    #     "candidatesTokenCount": 27,
    #     "totalTokenCount": 36
    #   }
    if "usageMetadata" in response_json:
        usage = UsageStatistics(
            prompt_tokens=response_json["usageMetadata"]["promptTokenCount"],
            completion_tokens=response_json["usageMetadata"]["candidatesTokenCount"],
            total_tokens=response_json["usageMetadata"]["totalTokenCount"],
        )
    else:
        # Count it ourselves
        assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required"
        prompt_tokens = count_tokens(
            json.dumps(input_messages, ensure_ascii=JSON_ENSURE_ASCII)
        )  # NOTE: this is a very rough approximation
        # Summing over the collected messages is identical for the single-candidate case and
        # yields 0 (instead of an unbound-variable error) when there were no candidates
        completion_tokens = sum(
            count_tokens(json.dumps(msg.model_dump(), ensure_ascii=JSON_ENSURE_ASCII)) for msg in response_messages
        )  # NOTE: this is also approximate
        total_tokens = prompt_tokens + completion_tokens
        usage = UsageStatistics(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )

    response_id = str(uuid.uuid4())
    return ChatCompletionResponse(
        id=response_id,
        choices=choices,
        model=model,  # NOTE: Google API doesn't pass back model in the response
        created=get_utc_time(),
        usage=usage,
    )
# TODO convert 'data' type to pydantic
def google_ai_chat_completions_request(
    service_endpoint: str,
    model: str,
    api_key: str,
    data: dict,
    key_in_header: bool = True,
    add_postfunc_model_messages: bool = True,
    # NOTE: Google AI API doesn't support mixing parts 'text' and 'function',
    # so there's no clean way to put inner thoughts in the same message as a function call
    inner_thoughts_in_kwargs: bool = True,
) -> ChatCompletionResponse:
    """Send a generateContent request to the Google AI API and return it in ChatCompletion form.

    https://ai.google.dev/docs/function_calling

    From https://ai.google.dev/api/rest#service-endpoint:
    "A service endpoint is a base URL that specifies the network address of an API service.
    One service might have multiple service endpoints.
    This service has the following service endpoint and all URIs below are relative to this service endpoint:
    https://xxx.googleapis.com"

    Args:
        service_endpoint: the ``xxx`` subdomain in ``https://xxx.googleapis.com``.
        model: model name; must be one of SUPPORTED_MODELS.
        api_key: Google AI API key.
        data: request payload; ``data["contents"]`` is replaced in place when
            ``add_postfunc_model_messages`` is set.
        key_in_header: send the key in the ``x-goog-api-key`` header (True) or as the
            ``key`` query parameter (False).
        add_postfunc_model_messages: run ``add_dummy_model_messages`` over the contents first.
        inner_thoughts_in_kwargs: forwarded to the response converter as
            ``pull_inner_thoughts_from_args``.

    Raises:
        AssertionError: missing endpoint/key or unsupported model.
        requests.exceptions.RequestException: the request failed (re-raised after logging).
    """
    from memgpt.utils import printd

    assert service_endpoint is not None, "Missing service_endpoint when calling Google AI"
    assert api_key is not None, "Missing api_key when calling Google AI"
    assert model in SUPPORTED_MODELS, f"Model '{model}' not in supported models: {', '.join(SUPPORTED_MODELS)}"

    # Two ways to pass the key: https://ai.google.dev/tutorials/setup
    # The key is never interpolated into `url`, so the debug log below cannot leak it.
    url = f"https://{service_endpoint}.googleapis.com/v1beta/models/{model}:generateContent"
    headers = {"Content-Type": "application/json"}
    query_params = None
    if key_in_header:
        headers["x-goog-api-key"] = api_key
    else:
        query_params = {"key": api_key}

    # data["contents"][-1]["role"] = "model"
    if add_postfunc_model_messages:
        data["contents"] = add_dummy_model_messages(data["contents"])

    printd(f"Sending request to {url}")
    try:
        response = requests.post(url, headers=headers, params=query_params, json=data)
        printd(f"response = {response}")
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response.json = {response}")

        # Convert Google AI response to ChatCompletion style
        return convert_google_ai_response_to_chatcompletion(
            response_json=response,
            model=model,
            input_messages=data["contents"],
            pull_inner_thoughts_from_args=inner_thoughts_in_kwargs,
        )

    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}, payload={data}")
        # Print the HTTP status code
        print(f"HTTP Error: {http_err.response.status_code}")
        # Print the response content (error message from server)
        print(f"Message: {http_err.response.text}")
        raise http_err

    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise req_err

    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}")
        raise e

View File

@@ -0,0 +1,251 @@
import random
import time
import requests
import os
import time
from typing import List
from memgpt.credentials import MemGPTCredentials
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
from memgpt.data_types import AgentState, Message
from memgpt.llm_api.openai import openai_chat_completions_request
from memgpt.llm_api.azure_openai import azure_openai_chat_completions_request, MODEL_TO_AZURE_ENGINE
from memgpt.llm_api.google_ai import (
google_ai_chat_completions_request,
convert_tools_to_google_ai_format,
)
def is_context_overflow_error(exception: requests.exceptions.RequestException) -> bool:
    """Checks if an exception is due to context overflow (based on common OpenAI response messages)

    Args:
        exception: the exception raised by an LLM backend request.

    Returns:
        True when the exception text, or the JSON error body of an HTTPError, indicates
        that the context window was exceeded; False in every other case.
    """
    from memgpt.utils import printd

    match_string = "maximum context length"

    # Backwards compatibility with openai python package/client v0.28 (pre-v1 client migration)
    if match_string in str(exception):
        printd(f"Found '{match_string}' in str(exception)={(str(exception))}")
        return True

    # Everything below is based on python requests + OpenAI REST API (/v1)
    if not isinstance(exception, requests.exceptions.HTTPError):
        return False

    response = exception.response
    if response is None or "application/json" not in response.headers.get("Content-Type", ""):
        # No JSON error body to inspect (this path used to fall through and return None)
        return False

    try:
        error_details = response.json()
    except ValueError:
        # JSON decoding failed
        printd(f"HTTPError occurred ({exception}), but no JSON error message.")
        return False

    if "error" not in error_details:
        printd(f"HTTPError occurred, but couldn't find error field: {error_details}")
        return False

    error_details = error_details["error"]
    if not isinstance(error_details, dict):
        # e.g. a plain string error; there is no code/message field to check
        printd(f"HTTPError occurred, but unknown error message: {error_details}")
        return False

    # Check for the specific error code
    if error_details.get("code") == "context_length_exceeded":
        printd(f"HTTPError occurred, caught error code {error_details.get('code')}")
        return True

    # Soft-check for "maximum context length" inside of the message
    if error_details.get("message") and match_string in error_details.get("message"):
        printd(f"HTTPError occurred, found '{match_string}' in error message contents ({error_details})")
        return True

    printd(f"HTTPError occurred, but unknown error message: {error_details}")
    return False
def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 20,
    # List of OpenAI error codes: https://github.com/openai/openai-python/blob/17ac6779958b2b74999c634c4ea4c7b74906027a/src/openai/_client.py#L227-L250
    # 429 = rate limit
    error_codes: tuple = (429,),
):
    """Retry a function with exponential backoff.

    Args:
        func: the callable to wrap.
        initial_delay: starting delay in seconds; it is multiplied before the first sleep.
        exponential_base: factor the delay grows by on every retry.
        jitter: when true, the growth factor is scaled by a random amount in [1, 2).
        max_retries: number of retries allowed before giving up.
        error_codes: HTTP status codes that trigger a retry.

    Returns:
        A wrapper around `func` that retries on `requests.exceptions.HTTPError` whose status
        code is in `error_codes`. Every other exception propagates unchanged. Once
        `max_retries` is exceeded a plain `Exception` is raised, chained to the last HTTPError.
    """
    import functools

    @functools.wraps(func)  # keep the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        # Initialize variables
        num_retries = 0
        delay = initial_delay

        # Loop until a successful response or max_retries is hit or an exception is raised
        while True:
            try:
                return func(*args, **kwargs)

            except requests.exceptions.HTTPError as http_err:
                # An HTTPError may carry no response at all; treat that as "not retryable"
                status_code = getattr(http_err.response, "status_code", None)
                if status_code not in error_codes:
                    # For other HTTP errors, re-raise the exception
                    raise

                # Increment retries
                num_retries += 1

                # Check if max retries has been reached
                if num_retries > max_retries:
                    raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") from http_err

                # Increment the delay
                delay *= exponential_base * (1 + jitter * random.random())

                # Sleep for the delay
                print(
                    f"{CLI_WARNING_PREFIX}Got a rate limit error ('{http_err}') on LLM backend request, waiting {int(delay)}s then retrying..."
                )
                time.sleep(delay)

    return wrapper
@retry_with_exponential_backoff
def create(
    agent_state: AgentState,
    messages: List[Message],
    functions=None,
    functions_python=None,
    function_call="auto",
    # hint
    first_message=False,
    # use tool naming?
    # if false, will use deprecated 'functions' style
    use_tool_naming=True,
) -> ChatCompletionResponse:
    """Return response to chat completion with backoff

    Dispatches on `agent_state.llm_config.model_endpoint_type`: "openai", "azure" and
    "google_ai" go to their REST helpers; any other value goes to the local-LLM proxy.

    Args:
        agent_state: supplies the LLM config (model, endpoint, endpoint type, ...) and user id.
        messages: conversation to send.
        functions: function schemas to expose to the model, or None.
        functions_python: forwarded to the local-LLM proxy only.
        function_call: function/tool choice; unset when no functions are given.
        first_message: hint forwarded to the local-LLM proxy only.
        use_tool_naming: use the 'tools' request style instead of the deprecated 'functions' style.

    Raises:
        ValueError: OpenAI endpoint selected but no OpenAI key is configured.
        NotImplementedError: Google AI selected with `use_tool_naming=False`.
    """
    from memgpt.utils import printd

    printd(f"Using model {agent_state.llm_config.model_endpoint_type}, endpoint: {agent_state.llm_config.model_endpoint}")

    # TODO eventually refactor so that credentials are passed through
    credentials = MemGPTCredentials.load()

    if function_call and not functions:
        printd("unsetting function_call because functions is None")
        function_call = None

    # openai
    if agent_state.llm_config.model_endpoint_type == "openai":
        # TODO do the same for Azure?
        if credentials.openai_key is None and agent_state.llm_config.model_endpoint == "https://api.openai.com/v1":
            # only is a problem if we are *not* using an openai proxy
            raise ValueError("OpenAI key is missing from MemGPT config file")
        if use_tool_naming:
            data = ChatCompletionRequest(
                model=agent_state.llm_config.model,
                messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
                tools=[{"type": "function", "function": f} for f in functions] if functions else None,
                tool_choice=function_call,
                user=str(agent_state.user_id),
            )
        else:
            data = ChatCompletionRequest(
                model=agent_state.llm_config.model,
                messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
                functions=functions,
                function_call=function_call,
                user=str(agent_state.user_id),
            )

        return openai_chat_completions_request(
            url=agent_state.llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
            api_key=credentials.openai_key,
            data=data,
        )

    # azure
    elif agent_state.llm_config.model_endpoint_type == "azure":
        azure_deployment = (
            credentials.azure_deployment
            if credentials.azure_deployment is not None
            else MODEL_TO_AZURE_ENGINE[agent_state.llm_config.model]
        )
        # NOTE(review): `messages` is passed through as-is here, while the OpenAI branch converts
        # each message with to_openai_dict() -- confirm the Azure request helper accepts this
        if use_tool_naming:
            data = dict(
                # NOTE: don't pass model to Azure calls, that is the deployment_id
                # model=agent_config.model,
                messages=messages,
                tools=[{"type": "function", "function": f} for f in functions] if functions else None,
                tool_choice=function_call,
                user=str(agent_state.user_id),
            )
        else:
            data = dict(
                # NOTE: don't pass model to Azure calls, that is the deployment_id
                # model=agent_config.model,
                messages=messages,
                functions=functions,
                function_call=function_call,
                user=str(agent_state.user_id),
            )
        return azure_openai_chat_completions_request(
            resource_name=credentials.azure_endpoint,
            deployment_id=azure_deployment,
            api_version=credentials.azure_version,
            api_key=credentials.azure_key,
            data=data,
        )

    elif agent_state.llm_config.model_endpoint_type == "google_ai":
        if not use_tool_naming:
            raise NotImplementedError("Only tool calling supported on Google AI API requests")

        # NOTE: until Google AI supports CoT / text alongside function calls,
        # we need to put it in a kwarg (unless we want to split the message into two)
        google_ai_inner_thoughts_in_kwarg = True

        if functions is not None:
            tools = [{"type": "function", "function": f} for f in functions]
            tools = [Tool(**t) for t in tools]
            # The converter already returns the list expected under "tools"; it must not be
            # wrapped again (a stray trailing comma used to turn it into a 1-tuple)
            tools = convert_tools_to_google_ai_format(tools, inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg)
        else:
            tools = None

        return google_ai_chat_completions_request(
            inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg,
            service_endpoint=credentials.google_ai_service_endpoint,
            model=agent_state.llm_config.model,
            api_key=credentials.google_ai_key,
            # see structure of payload here: https://ai.google.dev/docs/function_calling
            data=dict(
                contents=[m.to_google_ai_dict() for m in messages],
                tools=tools,
            ),
        )

    # local model
    else:
        return get_chat_completion(
            model=agent_state.llm_config.model,
            messages=messages,
            functions=functions,
            functions_python=functions_python,
            function_call=function_call,
            context_window=agent_state.llm_config.context_window,
            endpoint=agent_state.llm_config.model_endpoint,
            endpoint_type=agent_state.llm_config.model_endpoint_type,
            wrapper=agent_state.llm_config.model_wrapper,
            user=str(agent_state.user_id),
            # hint
            first_message=first_message,
            # auth-related
            auth_type=credentials.openllm_auth_type,
            auth_key=credentials.openllm_key,
        )

138
memgpt/llm_api/openai.py Normal file
View File

@@ -0,0 +1,138 @@
import requests
import time
from typing import Union, Optional
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.chat_completion_request import ChatCompletionRequest
from memgpt.models.embedding_response import EmbeddingResponse
from memgpt.utils import smart_urljoin
def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict:
    """List the models available at an OpenAI-compatible endpoint.

    https://platform.openai.com/docs/api-reference/models/list

    Args:
        url: base URL of the API; "models" is appended to it.
        api_key: sent as a Bearer token when not None.
        fix_url: when truthy, a missing "/v1" suffix is added to `url` first.

    Returns:
        The decoded JSON response.

    Raises:
        requests.exceptions.RequestException: the request failed (re-raised after logging).
    """
    from memgpt.utils import printd

    # In some cases we may want to double-check the URL and do basic correction, eg:
    # In MemGPT config the address for vLLM is w/o a /v1 suffix for simplicity
    # However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit
    if fix_url and not url.endswith("/v1"):
        url = smart_urljoin(url, "v1")

    url = smart_urljoin(url, "models")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    def _describe(resp):
        """Best-effort decode of a response body for log output; falls back to the object itself."""
        try:
            return resp.json()
        except Exception:
            return resp

    printd(f"Sending request to {url}")
    # Bound up front so the handlers can still log when requests.get() itself raises;
    # otherwise the log line fails on an unbound name and hides the real error
    response = None
    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response = {response}")
        return response
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}, response={_describe(response)}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}, response={_describe(response)}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}, response={_describe(response)}")
        raise e
def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletionRequest) -> ChatCompletionResponse:
    """POST a chat completion request to the "chat/completions" route under `url`.

    https://platform.openai.com/docs/guides/text-generation?lang=curl

    The request model is dumped to a dict (None fields excluded), sent as JSON with a
    Bearer token, and the JSON reply is wrapped in a ChatCompletionResponse.
    Any failure is logged through `printd` and re-raised unchanged.
    """
    from memgpt.utils import printd

    endpoint = smart_urljoin(url, "chat/completions")
    request_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    payload = data.model_dump(exclude_none=True)

    # A None schema list is stripped together with its companion "choice" field
    # (the companion should always exist, default="auto", hence the tolerant pop)
    for schema_field, choice_field in (("functions", "function_call"), ("tools", "tool_choice")):
        if schema_field in payload and payload[schema_field] is None:
            payload.pop(schema_field)
            payload.pop(choice_field, None)

    printd(f"Sending request to {endpoint}")
    try:
        # Example code to trigger a rate limit response:
        # mock_response = requests.Response()
        # mock_response.status_code = 429
        # http_error = requests.exceptions.HTTPError("429 Client Error: Too Many Requests")
        # http_error.response = mock_response
        # raise http_error

        # Example code to trigger a context overflow response (for an 8k model)
        # data["messages"][-1]["content"] = " ".join(["repeat after me this is not a fluke"] * 1000)

        http_response = requests.post(endpoint, headers=request_headers, json=payload)
        printd(f"response = {http_response}")
        http_response.raise_for_status()  # 4XX/5XX -> HTTPError
        body = http_response.json()
        printd(f"response.json = {body}")
        # 'dot-dict' style access, which is the openai python client default
        return ChatCompletionResponse(**body)
    except requests.exceptions.HTTPError as err:
        # HTTP-level failure (4XX / 5XX)
        printd(f"Got HTTPError, exception={err}, payload={payload}")
        raise
    except requests.exceptions.RequestException as err:
        # Transport-level failure (e.g. connection error)
        printd(f"Got RequestException, exception={err}")
        raise
    except Exception as err:
        # Anything else
        printd(f"Got unknown Exception, exception={err}")
        raise
def openai_embeddings_request(url: str, api_key: str, data: dict) -> EmbeddingResponse:
    """POST `data` to the "embeddings" route under `url` and wrap the reply in an EmbeddingResponse.

    https://platform.openai.com/docs/api-reference/embeddings/create

    Any failure is logged through `printd` and re-raised unchanged.
    """
    from memgpt.utils import printd

    endpoint = smart_urljoin(url, "embeddings")
    request_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

    printd(f"Sending request to {endpoint}")
    try:
        http_response = requests.post(endpoint, headers=request_headers, json=data)
        printd(f"response = {http_response}")
        http_response.raise_for_status()  # 4XX/5XX -> HTTPError
        body = http_response.json()
        printd(f"response.json = {body}")
        # 'dot-dict' style access, which is the openai python client default
        return EmbeddingResponse(**body)
    except requests.exceptions.HTTPError as err:
        # HTTP-level failure (4XX / 5XX)
        printd(f"Got HTTPError, exception={err}, payload={data}")
        raise
    except requests.exceptions.RequestException as err:
        # Transport-level failure (e.g. connection error)
        printd(f"Got RequestException, exception={err}")
        raise
    except Exception as err:
        # Anything else
        printd(f"Got unknown Exception, exception={err}")
        raise

View File

@@ -1,495 +0,0 @@
import random
import time
import requests
import time
from typing import Union, Optional
import urllib
from memgpt.credentials import MemGPTCredentials
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.embedding_response import EmbeddingResponse
from memgpt.data_types import AgentState
# Model name -> Azure engine name. Grouped by engine, since several model names share one engine.
MODEL_TO_AZURE_ENGINE = {
    **dict.fromkeys(("gpt-4-1106-preview", "gpt-4"), "gpt-4"),
    "gpt-4-32k": "gpt-4-32k",
    **dict.fromkeys(("gpt-3.5", "gpt-3.5-turbo"), "gpt-35-turbo"),
    "gpt-3.5-turbo-16k": "gpt-35-turbo-16k",
}
def is_context_overflow_error(exception):
    """Check whether an exception was caused by exceeding the model's context window.

    Args:
        exception: the exception raised by an LLM backend request.

    Returns:
        True when the exception text, or the JSON error body of an HTTPError, indicates
        a context overflow; False in every other case.
    """
    from memgpt.utils import printd

    match_string = "maximum context length"

    # Backwards compatibility with openai python package/client v0.28 (pre-v1 client migration)
    if match_string in str(exception):
        printd(f"Found '{match_string}' in str(exception)={(str(exception))}")
        return True

    # Everything below is based on python requests + OpenAI REST API (/v1)
    if not isinstance(exception, requests.exceptions.HTTPError):
        return False

    response = exception.response
    if response is None or "application/json" not in response.headers.get("Content-Type", ""):
        # No JSON error body to inspect (this path used to fall through and return None)
        return False

    try:
        error_details = response.json()
    except ValueError:
        # JSON decoding failed
        printd(f"HTTPError occurred ({exception}), but no JSON error message.")
        return False

    if "error" not in error_details:
        printd(f"HTTPError occurred, but couldn't find error field: {error_details}")
        return False

    error_details = error_details["error"]
    if not isinstance(error_details, dict):
        # e.g. a plain string error; there is no code/message field to check
        printd(f"HTTPError occurred, but unknown error message: {error_details}")
        return False

    # Check for the specific error code
    if error_details.get("code") == "context_length_exceeded":
        printd(f"HTTPError occurred, caught error code {error_details.get('code')}")
        return True

    # Soft-check for "maximum context length" inside of the message
    if error_details.get("message") and match_string in error_details.get("message"):
        printd(f"HTTPError occurred, found '{match_string}' in error message contents ({error_details})")
        return True

    printd(f"HTTPError occurred, but unknown error message: {error_details}")
    return False
def smart_urljoin(base_url, relative_url):
    """Join `relative_url` onto `base_url` without losing the base's last path segment.

    urljoin wants a trailing / at the end of the endpoint address, or it will chop the
    suffix off, so one is added when missing.
    """
    # Import the submodule explicitly: a bare `import urllib` does not guarantee that
    # `urllib.parse` has been loaded
    from urllib.parse import urljoin

    if not base_url.endswith("/"):
        base_url += "/"
    return urljoin(base_url, relative_url)
def clean_azure_endpoint(raw_endpoint_name):
    """Reduce an Azure OpenAI endpoint address to the bare YOUR_RESOURCE_NAME part.

    Surrounding slashes, the ".openai.azure.com" suffix and any http(s) scheme are removed,
    so the result can be placed back into 'https://YOUR_RESOURCE_NAME.openai.azure.com'.

    Raises:
        ValueError: `raw_endpoint_name` is None.
    """
    if raw_endpoint_name is None:
        raise ValueError(raw_endpoint_name)

    resource = raw_endpoint_name.strip("/")
    for fragment in (".openai.azure.com", "http://", "https://"):
        resource = resource.replace(fragment, "")
    return resource
def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict:
    """List the models available at an OpenAI-compatible endpoint.

    https://platform.openai.com/docs/api-reference/models/list

    Args:
        url: base URL of the API; "models" is appended to it.
        api_key: sent as a Bearer token when not None.
        fix_url: when truthy, a missing "/v1" suffix is added to `url` first.

    Returns:
        The decoded JSON response.

    Raises:
        requests.exceptions.RequestException: the request failed (re-raised after logging).
    """
    from memgpt.utils import printd

    # In some cases we may want to double-check the URL and do basic correction, eg:
    # In MemGPT config the address for vLLM is w/o a /v1 suffix for simplicity
    # However if we're treating the server as an OpenAI proxy we want the /v1 suffix on our model hit
    if fix_url and not url.endswith("/v1"):
        url = smart_urljoin(url, "v1")

    url = smart_urljoin(url, "models")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    def _describe(resp):
        """Best-effort decode of a response body for log output; falls back to the object itself."""
        try:
            return resp.json()
        except Exception:
            return resp

    printd(f"Sending request to {url}")
    # Bound up front so the handlers can still log when requests.get() itself raises;
    # otherwise the log line fails on an unbound name and hides the real error
    response = None
    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response = {response}")
        return response
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}, response={_describe(response)}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}, response={_describe(response)}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}, response={_describe(response)}")
        raise e
def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version: str) -> dict:
    """List the models available on an Azure OpenAI resource.

    https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP

    Args:
        url: base URL of the resource, e.g. https://xxx.openai.azure.com
        api_key: sent in the "api-key" header when not None.
        api_version: value of the "api-version" query parameter.

    Returns:
        The decoded JSON response.

    Raises:
        requests.exceptions.RequestException: the request failed (re-raised after logging).
    """
    from memgpt.utils import printd

    # https://xxx.openai.azure.com/openai/models?api-version=xxx
    url = smart_urljoin(url, "openai")
    url = smart_urljoin(url, f"models?api-version={api_version}")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["api-key"] = f"{api_key}"

    def _describe(resp):
        """Best-effort decode of a response body for log output; falls back to the object itself."""
        try:
            return resp.json()
        except Exception:
            return resp

    printd(f"Sending request to {url}")
    # Bound up front so the handlers can still log when requests.get() itself raises;
    # otherwise the log line fails on an unbound name and hides the real error
    response = None
    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        response = response.json()  # convert to dict from string
        printd(f"response = {response}")
        return response
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}, response={_describe(response)}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}, response={_describe(response)}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}, response={_describe(response)}")
        raise e
def openai_chat_completions_request(url, api_key, data):
    """POST the chat completion payload `data` to the "chat/completions" route under `url`.

    https://platform.openai.com/docs/guides/text-generation?lang=curl

    `data` is edited in place: a "functions"/"tools" entry that is None is removed together
    with its companion choice field. The JSON reply is wrapped in a ChatCompletionResponse.
    Any failure is logged through `printd` and re-raised unchanged.
    """
    from memgpt.utils import printd

    endpoint = smart_urljoin(url, "chat/completions")
    request_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

    # A None schema list is stripped together with its companion "choice" field
    # (the companion should always exist, default="auto", hence the tolerant pop)
    for schema_field, choice_field in (("functions", "function_call"), ("tools", "tool_choice")):
        if schema_field in data and data[schema_field] is None:
            data.pop(schema_field)
            data.pop(choice_field, None)

    printd(f"Sending request to {endpoint}")
    try:
        # Example code to trigger a rate limit response:
        # mock_response = requests.Response()
        # mock_response.status_code = 429
        # http_error = requests.exceptions.HTTPError("429 Client Error: Too Many Requests")
        # http_error.response = mock_response
        # raise http_error

        # Example code to trigger a context overflow response (for an 8k model)
        # data["messages"][-1]["content"] = " ".join(["repeat after me this is not a fluke"] * 1000)

        http_response = requests.post(endpoint, headers=request_headers, json=data)
        printd(f"response = {http_response}")
        http_response.raise_for_status()  # 4XX/5XX -> HTTPError
        body = http_response.json()
        printd(f"response.json = {body}")
        # 'dot-dict' style access, which is the openai python client default
        return ChatCompletionResponse(**body)
    except requests.exceptions.HTTPError as err:
        # HTTP-level failure (4XX / 5XX)
        printd(f"Got HTTPError, exception={err}, payload={data}")
        raise
    except requests.exceptions.RequestException as err:
        # Transport-level failure (e.g. connection error)
        printd(f"Got RequestException, exception={err}")
        raise
    except Exception as err:
        # Anything else
        printd(f"Got unknown Exception, exception={err}")
        raise
def openai_embeddings_request(url, api_key, data):
    """POST `data` to the "embeddings" route under `url` and wrap the reply in an EmbeddingResponse.

    https://platform.openai.com/docs/api-reference/embeddings/create

    Any failure is logged through `printd` and re-raised unchanged.
    """
    from memgpt.utils import printd

    endpoint = smart_urljoin(url, "embeddings")
    request_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}

    printd(f"Sending request to {endpoint}")
    try:
        http_response = requests.post(endpoint, headers=request_headers, json=data)
        printd(f"response = {http_response}")
        http_response.raise_for_status()  # 4XX/5XX -> HTTPError
        body = http_response.json()
        printd(f"response.json = {body}")
        # 'dot-dict' style access, which is the openai python client default
        return EmbeddingResponse(**body)
    except requests.exceptions.HTTPError as err:
        # HTTP-level failure (4XX / 5XX)
        printd(f"Got HTTPError, exception={err}, payload={data}")
        raise
    except requests.exceptions.RequestException as err:
        # Transport-level failure (e.g. connection error)
        printd(f"Got RequestException, exception={err}")
        raise
    except Exception as err:
        # Anything else
        printd(f"Got unknown Exception, exception={err}")
        raise
def azure_openai_chat_completions_request(resource_name, deployment_id, api_version, api_key, data):
    """POST the chat completion payload `data` to an Azure OpenAI deployment.

    https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions

    `data` is edited in place: a "functions"/"tools" entry that is None is removed together
    with its companion choice field. The JSON reply is wrapped in a ChatCompletionResponse.
    Any failure is logged through `printd` and re-raised unchanged.
    """
    from memgpt.utils import printd

    # All four connection settings are mandatory
    for required_field in (resource_name, deployment_id, api_version, api_key):
        assert required_field is not None, "Missing required field when calling Azure OpenAI"

    resource = clean_azure_endpoint(resource_name)
    endpoint = f"https://{resource}.openai.azure.com/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}"
    request_headers = {"Content-Type": "application/json", "api-key": f"{api_key}"}

    # A None schema list is stripped together with its companion "choice" field
    # (the companion should always exist, default="auto", hence the tolerant pop)
    for schema_field, choice_field in (("functions", "function_call"), ("tools", "tool_choice")):
        if schema_field in data and data[schema_field] is None:
            data.pop(schema_field)
            data.pop(choice_field, None)

    printd(f"Sending request to {endpoint}")
    try:
        http_response = requests.post(endpoint, headers=request_headers, json=data)
        printd(f"response = {http_response}")
        http_response.raise_for_status()  # 4XX/5XX -> HTTPError
        body = http_response.json()
        printd(f"response.json = {body}")
        # NOTE: azure openai does not include "content" in the response when it is None, so we need to add it
        first_message = body["choices"][0].get("message")
        if "content" not in first_message:
            first_message["content"] = None
        # 'dot-dict' style access, which is the openai python client default
        return ChatCompletionResponse(**body)
    except requests.exceptions.HTTPError as err:
        # HTTP-level failure (4XX / 5XX)
        printd(f"Got HTTPError, exception={err}, payload={data}")
        raise
    except requests.exceptions.RequestException as err:
        # Transport-level failure (e.g. connection error)
        printd(f"Got RequestException, exception={err}")
        raise
    except Exception as err:
        # Anything else
        printd(f"Got unknown Exception, exception={err}")
        raise
def azure_openai_embeddings_request(resource_name, deployment_id, api_version, api_key, data):
    """POST an embeddings request to an Azure OpenAI deployment.

    API reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings

    Args:
        resource_name: Azure resource name; passed through ``clean_azure_endpoint`` before use.
        deployment_id: Name of the Azure deployment to target.
        api_version: Value for the ``api-version`` query parameter.
        api_key: Key sent in the ``api-key`` header.
        data: JSON-serializable request payload.

    Returns:
        The decoded JSON response wrapped in an ``EmbeddingResponse``.

    Raises:
        requests.exceptions.HTTPError: On a 4XX/5XX response status.
        requests.exceptions.RequestException: On other request failures (e.g. connection errors).
    """
    from memgpt.utils import printd

    endpoint_host = clean_azure_endpoint(resource_name)
    request_url = f"https://{endpoint_host}.openai.azure.com/openai/deployments/{deployment_id}/embeddings?api-version={api_version}"
    request_headers = {"Content-Type": "application/json", "api-key": f"{api_key}"}

    printd(f"Sending request to {request_url}")
    try:
        http_response = requests.post(request_url, headers=request_headers, json=data)
        printd(f"response = {http_response}")
        # Raises HTTPError for 4XX/5XX status
        http_response.raise_for_status()
        # Decode the JSON body into a plain dict
        response_body = http_response.json()
        printd(f"response.json = {response_body}")
        # Convert to 'dot-dict' style, which is the openai python client default
        return EmbeddingResponse(**response_body)
    except requests.exceptions.HTTPError as http_err:
        # HTTP-level failure: log the payload alongside the error to aid debugging
        printd(f"Got HTTPError, exception={http_err}, payload={data}")
        raise
    except requests.exceptions.RequestException as req_err:
        # Any other requests-related failure (e.g. connection error)
        printd(f"Got RequestException, exception={req_err}")
        raise
    except Exception as e:
        # Anything unexpected (e.g. response validation); log and propagate
        printd(f"Got unknown Exception, exception={e}")
        raise
def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 20,
    # List of OpenAI error codes: https://github.com/openai/openai-python/blob/17ac6779958b2b74999c634c4ea4c7b74906027a/src/openai/_client.py#L227-L250
    # 429 = rate limit
    error_codes: tuple = (429,),
):
    """Wrap ``func`` so that retryable HTTP errors are retried with exponential backoff.

    Args:
        func: Callable to wrap; its name and docstring are preserved on the wrapper.
        initial_delay: Starting delay in seconds (it is scaled up before the first sleep).
        exponential_base: Factor the delay is multiplied by after each retryable failure.
        jitter: When true, the growth factor is additionally scaled by a random value in [1, 2).
        max_retries: Number of retries allowed before giving up.
        error_codes: HTTP status codes that trigger a retry.

    Returns:
        A wrapper with the same call signature as ``func``.

    Raises:
        Exception: From the wrapper, once more than ``max_retries`` retryable errors have
            occurred; the last ``HTTPError`` is attached as ``__cause__``.
        Any exception raised by ``func`` that is not a retryable ``HTTPError`` propagates unchanged.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay

        # Loop until a successful response, max_retries is hit, or a non-retryable exception is raised
        while True:
            try:
                return func(*args, **kwargs)
            except requests.exceptions.HTTPError as http_err:
                # An HTTPError built without a response has no status code and cannot be classified as retryable
                status_code = getattr(http_err.response, "status_code", None)
                if status_code not in error_codes:
                    raise

                num_retries += 1
                if num_retries > max_retries:
                    # Chain the last HTTP error so the root cause stays visible in the traceback
                    raise Exception(f"Maximum number of retries ({max_retries}) exceeded.") from http_err

                # Grow the delay before sleeping, so even the first wait exceeds initial_delay
                delay *= exponential_base * (1 + jitter * random.random())

                # printd(f"Got a rate limit error ('{http_err}') on LLM backend request, waiting {int(delay)}s then retrying...")
                print(
                    f"{CLI_WARNING_PREFIX}Got a rate limit error ('{http_err}') on LLM backend request, waiting {int(delay)}s then retrying..."
                )
                time.sleep(delay)

    return wrapper
@retry_with_exponential_backoff
def create(
    agent_state: AgentState,
    messages,
    functions=None,
    functions_python=None,
    function_call="auto",
    # hint
    first_message=False,
    # use tool naming?
    # if false, will use deprecated 'functions' style
    use_tool_naming=True,
) -> ChatCompletionResponse:
    """Send a chat completion request to the backend named by the agent's LLM config.

    The endpoint type picks the route: "openai" and "azure" go to their respective
    request helpers, and every other value is handed to ``get_chat_completion``.
    The decorator retries the call with backoff on retryable HTTP errors.

    Args:
        agent_state: Supplies ``llm_config`` and ``user_id`` for the request.
        messages: Conversation to send to the model.
        functions: Function schemas to expose to the model, if any.
        functions_python: Forwarded only to the local-model route.
        function_call: Function/tool choice directive; cleared when no functions are given.
        first_message: Hint forwarded only to the local-model route.
        use_tool_naming: When true, send ``tools``/``tool_choice``; otherwise send the
            deprecated ``functions``/``function_call`` fields.

    Raises:
        ValueError: If the official OpenAI endpoint is targeted without an OpenAI key.
    """
    from memgpt.utils import printd

    llm_config = agent_state.llm_config
    printd(f"Using model {llm_config.model_endpoint_type}, endpoint: {llm_config.model_endpoint}")

    # TODO eventually refactor so that credentials are passed through
    credentials = MemGPTCredentials.load()

    if function_call and not functions:
        printd("unsetting function_call because functions is None")
        function_call = None

    def _build_payload(include_model):
        """Assemble the request body shared by the OpenAI and Azure routes."""
        payload = {}
        if include_model:
            payload["model"] = llm_config.model
        payload["messages"] = messages
        if use_tool_naming:
            payload["tools"] = [{"type": "function", "function": spec} for spec in functions] if functions else None
            payload["tool_choice"] = function_call
        else:
            payload["functions"] = functions
            payload["function_call"] = function_call
        payload["user"] = str(agent_state.user_id)
        return payload

    endpoint_type = llm_config.model_endpoint_type

    # openai
    if endpoint_type == "openai":
        # TODO do the same for Azure?
        if credentials.openai_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
            # only is a problem if we are *not* using an openai proxy
            raise ValueError("OpenAI key is missing from MemGPT config file")
        request_data = _build_payload(include_model=True)
        return openai_chat_completions_request(
            url=llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
            api_key=credentials.openai_key,
            data=request_data,
        )

    # azure
    if endpoint_type == "azure":
        # Prefer the configured deployment; otherwise derive it from the model name
        azure_deployment = credentials.azure_deployment
        if azure_deployment is None:
            azure_deployment = MODEL_TO_AZURE_ENGINE[llm_config.model]
        # NOTE: don't pass model to Azure calls, that is the deployment_id
        request_data = _build_payload(include_model=False)
        return azure_openai_chat_completions_request(
            resource_name=credentials.azure_endpoint,
            deployment_id=azure_deployment,
            api_version=credentials.azure_version,
            api_key=credentials.azure_key,
            data=request_data,
        )

    # local model
    return get_chat_completion(
        model=llm_config.model,
        messages=messages,
        functions=functions,
        functions_python=functions_python,
        function_call=function_call,
        context_window=llm_config.context_window,
        endpoint=llm_config.model_endpoint,
        endpoint_type=llm_config.model_endpoint_type,
        wrapper=llm_config.model_wrapper,
        user=str(agent_state.user_id),
        # hint
        first_message=first_message,
        # auth-related
        auth_type=credentials.openllm_auth_type,
        auth_key=credentials.openllm_key,
    )

View File

@@ -23,3 +23,6 @@ DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"
DEFAULT_WRAPPER = ChatMLInnerMonologueWrapper
DEFAULT_WRAPPER_NAME = "chatml"
INNER_THOUGHTS_KWARG = "inner_thoughts"
INNER_THOUGHTS_KWARG_DESCRIPTION = "Deep inner monologue private to you only."

View File

@@ -5,6 +5,18 @@ from memgpt.constants import JSON_LOADS_STRICT
from memgpt.errors import LLMJSONParsingError
def clean_json_string_extra_backslash(s):
    """Collapse every run of consecutive backslashes in ``s`` down to a single backslash.

    NOTE: Google AI Gemini API likes to include these extra backslashes in stringified JSON
    """
    doubled = "\\" * 2
    single = "\\"

    cleaned = s
    # One pass only halves each run, so keep going until no doubled backslash is left
    while doubled in cleaned:
        cleaned = cleaned.replace(doubled, single)
    return cleaned
def replace_escaped_underscores(string: str):
"""Handles the case of escaped underscores, e.g.:

View File

@@ -1,9 +1,9 @@
import json
from .wrapper_base import LLMChatCompletionWrapper
from ..json_parser import clean_json
from ...constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from ...errors import LLMJSONParsingError
from memgpt.local_llm.llm_chat_completion_wrappers.wrapper_base import LLMChatCompletionWrapper
from memgpt.local_llm.json_parser import clean_json
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
from memgpt.errors import LLMJSONParsingError
PREFIX_HINT = """# Reminders:
@@ -75,7 +75,9 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
func_str += f"\n description: {schema['description']}"
func_str += f"\n params:"
if add_inner_thoughts:
func_str += f"\n inner_thoughts: Deep inner monologue private to you only."
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
func_str += f"\n {INNER_THOUGHTS_KWARG}: {INNER_THOUGHTS_KWARG_DESCRIPTION}"
for param_k, param_v in schema["parameters"]["properties"].items():
# TODO we're ignoring type
func_str += f"\n {param_k}: {param_v['description']}"

View File

@@ -1,6 +1,7 @@
import os
import sys
import traceback
import requests
import json
import questionary
@@ -271,7 +272,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
fg=typer.colors.GREEN,
bold=True,
)
except errors.LLMError as e:
except (errors.LLMError, requests.exceptions.HTTPError) as e:
typer.secho(
f"/summarize failed:\n{e}",
fg=typer.colors.RED,

View File

@@ -3,10 +3,10 @@ import datetime
import uuid
from typing import Optional, List, Tuple, Union
from memgpt.constants import MESSAGE_SUMMARY_WARNING_FRAC
from memgpt.constants import MESSAGE_SUMMARY_WARNING_FRAC, MESSAGE_SUMMARY_REQUEST_ACK
from memgpt.utils import get_local_time, printd, count_tokens, validate_date_format, extract_date_from_timestamp
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from memgpt.llm_api_tools import create
from memgpt.llm_api.llm_api_tools import create
from memgpt.data_types import Message, Passage, AgentState
from memgpt.embeddings import embedding_model, query_embedding, parse_and_chunk_text
@@ -102,16 +102,22 @@ class CoreMemory(object):
raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
def _format_summary_history(message_history: List[Message]):
    """Render the messages as newline-separated ``role: text`` lines."""
    # TODO use existing prompt formatters for this (eg ChatML)
    rendered_lines = []
    for message in message_history:
        rendered_lines.append(f"{message.role}: {message.text}")
    return "\n".join(rendered_lines)
def summarize_messages(
agent_state: AgentState,
message_sequence_to_summarize,
message_sequence_to_summarize: List[Message],
insert_acknowledgement_assistant_message: bool = True,
):
"""Summarize a message sequence using GPT"""
# we need the context_window
context_window = agent_state.llm_config.context_window
summary_prompt = SUMMARY_PROMPT_SYSTEM
summary_input = str(message_sequence_to_summarize)
summary_input = _format_summary_history(message_sequence_to_summarize)
summary_input_tkns = count_tokens(summary_input)
if summary_input_tkns > MESSAGE_SUMMARY_WARNING_FRAC * context_window:
trunc_ratio = (MESSAGE_SUMMARY_WARNING_FRAC * context_window / summary_input_tkns) * 0.8 # For good measure...
@@ -120,10 +126,14 @@ def summarize_messages(
[summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff])]
+ message_sequence_to_summarize[cutoff:]
)
message_sequence = [
{"role": "system", "content": summary_prompt},
{"role": "user", "content": summary_input},
]
dummy_user_id = uuid.uuid4()
dummy_agent_id = uuid.uuid4()
message_sequence = []
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=summary_prompt))
if insert_acknowledgement_assistant_message:
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="assistant", text=MESSAGE_SUMMARY_REQUEST_ACK))
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=summary_input))
response = create(
agent_state=agent_state,

View File

@@ -14,14 +14,46 @@ class UserMessage(BaseModel):
name: Optional[str] = None
class ToolCallFunction(BaseModel):
    # Function portion of a tool call: the function to invoke and its arguments.
    name: str
    # Arguments travel as a single string (presumably JSON-encoded, as in the OpenAI API; confirm against callers)
    arguments: str
class ToolCall(BaseModel):
    # A single tool invocation: an identifier, the tool type, and the function being called.
    id: str
    # The Literal annotation restricts this field to "function", which is also the default
    type: Literal["function"] = "function"
    function: ToolCallFunction
class AssistantMessage(BaseModel):
    # Chat message whose role defaults to "assistant"; content, name and tool calls are all optional.
    content: Optional[str] = None
    role: str = "assistant"
    name: Optional[str] = None
    # Declared once, with the element type spelled out so each entry is validated as a ToolCall
    tool_calls: Optional[List[ToolCall]] = None
ChatMessage = Union[SystemMessage, UserMessage, AssistantMessage]
class ToolMessage(BaseModel):
    # Chat message whose role defaults to "tool"; content is a required string here.
    content: str
    role: str = "tool"
    # presumably the `id` of the ToolCall this message responds to; confirm against callers
    tool_call_id: str
ChatMessage = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
def cast_message_to_subtype(m_dict: dict) -> ChatMessage:
    """Build the message model that matches the dictionary's ``"role"`` entry.

    Args:
        m_dict: Message fields; the whole dictionary is passed as keyword
            arguments to the selected model.

    Returns:
        A ``SystemMessage``, ``UserMessage``, ``AssistantMessage`` or ``ToolMessage``.

    Raises:
        ValueError: If ``"role"`` is missing or is not one of the four known roles.
    """
    role = m_dict.get("role")

    # Ordered (role, model) pairs, compared by equality
    role_models = (
        ("system", SystemMessage),
        ("user", UserMessage),
        ("assistant", AssistantMessage),
        ("tool", ToolMessage),
    )
    for known_role, model_cls in role_models:
        if role == known_role:
            return model_cls(**m_dict)

    raise ValueError("Unknown message role")
class ResponseFormat(BaseModel):

View File

@@ -17,7 +17,7 @@ from functools import wraps
from typing import get_type_hints, Union, _GenericAlias
from urllib.parse import urlparse
from urllib.parse import urlparse, urljoin
from contextlib import contextmanager
import difflib
import demjson3 as demjson
@@ -469,6 +469,13 @@ NOUN_BANK = [
]
def smart_urljoin(base_url: str, relative_url: str) -> str:
    """Join ``relative_url`` onto ``base_url`` without losing the base's last path segment.

    ``urljoin`` treats a base without a trailing slash as pointing at a file and
    replaces its final segment, so a trailing "/" is appended when it is missing.
    """
    normalized_base = base_url if base_url.endswith("/") else base_url + "/"
    return urljoin(normalized_base, relative_url)
def is_utc_datetime(dt: datetime) -> bool:
    """Return True when ``dt`` is timezone-aware and its UTC offset is exactly zero."""
    tz = dt.tzinfo
    # Naive datetimes carry no offset information, so they are never considered UTC
    if tz is None:
        return False
    return tz.utcoffset(dt) == timedelta(0)