|
|
|
|
@@ -1,13 +1,14 @@
|
|
|
|
|
import os
|
|
|
|
|
import random
|
|
|
|
|
import time
|
|
|
|
|
import uuid
|
|
|
|
|
from typing import List, Optional, Union
|
|
|
|
|
|
|
|
|
|
import requests
|
|
|
|
|
|
|
|
|
|
from memgpt.constants import CLI_WARNING_PREFIX
|
|
|
|
|
from memgpt.credentials import MemGPTCredentials
|
|
|
|
|
from memgpt.data_types import AgentState, Message
|
|
|
|
|
from memgpt.data_types import Message
|
|
|
|
|
from memgpt.llm_api.anthropic import anthropic_chat_completions_request
|
|
|
|
|
from memgpt.llm_api.azure_openai import (
|
|
|
|
|
MODEL_TO_AZURE_ENGINE,
|
|
|
|
|
@@ -29,6 +30,7 @@ from memgpt.models.chat_completion_request import (
|
|
|
|
|
cast_message_to_subtype,
|
|
|
|
|
)
|
|
|
|
|
from memgpt.models.chat_completion_response import ChatCompletionResponse
|
|
|
|
|
from memgpt.models.pydantic_models import LLMConfigModel
|
|
|
|
|
from memgpt.streaming_interface import (
|
|
|
|
|
AgentChunkStreamingInterface,
|
|
|
|
|
AgentRefreshStreamingInterface,
|
|
|
|
|
@@ -135,8 +137,10 @@ def retry_with_exponential_backoff(
|
|
|
|
|
|
|
|
|
|
@retry_with_exponential_backoff
|
|
|
|
|
def create(
|
|
|
|
|
agent_state: AgentState,
|
|
|
|
|
# agent_state: AgentState,
|
|
|
|
|
llm_config: LLMConfigModel,
|
|
|
|
|
messages: List[Message],
|
|
|
|
|
user_id: uuid.UUID = None, # option UUID to associate request with
|
|
|
|
|
functions: list = None,
|
|
|
|
|
functions_python: list = None,
|
|
|
|
|
function_call: str = "auto",
|
|
|
|
|
@@ -152,7 +156,7 @@ def create(
|
|
|
|
|
"""Return response to chat completion with backoff"""
|
|
|
|
|
from memgpt.utils import printd
|
|
|
|
|
|
|
|
|
|
printd(f"Using model {agent_state.llm_config.model_endpoint_type}, endpoint: {agent_state.llm_config.model_endpoint}")
|
|
|
|
|
printd(f"Using model {llm_config.model_endpoint_type}, endpoint: {llm_config.model_endpoint}")
|
|
|
|
|
|
|
|
|
|
# TODO eventually refactor so that credentials are passed through
|
|
|
|
|
credentials = MemGPTCredentials.load()
|
|
|
|
|
@@ -162,26 +166,26 @@ def create(
|
|
|
|
|
function_call = None
|
|
|
|
|
|
|
|
|
|
# openai
|
|
|
|
|
if agent_state.llm_config.model_endpoint_type == "openai":
|
|
|
|
|
if llm_config.model_endpoint_type == "openai":
|
|
|
|
|
# TODO do the same for Azure?
|
|
|
|
|
if credentials.openai_key is None and agent_state.llm_config.model_endpoint == "https://api.openai.com/v1":
|
|
|
|
|
if credentials.openai_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
|
|
|
|
|
# only is a problem if we are *not* using an openai proxy
|
|
|
|
|
raise ValueError(f"OpenAI key is missing from MemGPT config file")
|
|
|
|
|
if use_tool_naming:
|
|
|
|
|
data = ChatCompletionRequest(
|
|
|
|
|
model=agent_state.llm_config.model,
|
|
|
|
|
model=llm_config.model,
|
|
|
|
|
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
|
|
|
|
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
|
|
|
|
|
tool_choice=function_call,
|
|
|
|
|
user=str(agent_state.user_id),
|
|
|
|
|
user=str(user_id),
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
data = ChatCompletionRequest(
|
|
|
|
|
model=agent_state.llm_config.model,
|
|
|
|
|
model=llm_config.model,
|
|
|
|
|
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
|
|
|
|
functions=functions,
|
|
|
|
|
function_call=function_call,
|
|
|
|
|
user=str(agent_state.user_id),
|
|
|
|
|
user=str(user_id),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if stream:
|
|
|
|
|
@@ -190,7 +194,7 @@ def create(
|
|
|
|
|
stream_inferface, AgentRefreshStreamingInterface
|
|
|
|
|
), type(stream_inferface)
|
|
|
|
|
return openai_chat_completions_process_stream(
|
|
|
|
|
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
|
|
|
|
|
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
|
|
|
|
|
api_key=credentials.openai_key,
|
|
|
|
|
chat_completion_request=data,
|
|
|
|
|
stream_inferface=stream_inferface,
|
|
|
|
|
@@ -198,17 +202,15 @@ def create(
|
|
|
|
|
else:
|
|
|
|
|
data.stream = False
|
|
|
|
|
return openai_chat_completions_request(
|
|
|
|
|
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
|
|
|
|
|
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
|
|
|
|
|
api_key=credentials.openai_key,
|
|
|
|
|
chat_completion_request=data,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# azure
|
|
|
|
|
elif agent_state.llm_config.model_endpoint_type == "azure":
|
|
|
|
|
elif llm_config.model_endpoint_type == "azure":
|
|
|
|
|
azure_deployment = (
|
|
|
|
|
credentials.azure_deployment
|
|
|
|
|
if credentials.azure_deployment is not None
|
|
|
|
|
else MODEL_TO_AZURE_ENGINE[agent_state.llm_config.model]
|
|
|
|
|
credentials.azure_deployment if credentials.azure_deployment is not None else MODEL_TO_AZURE_ENGINE[llm_config.model]
|
|
|
|
|
)
|
|
|
|
|
if use_tool_naming:
|
|
|
|
|
data = dict(
|
|
|
|
|
@@ -217,7 +219,7 @@ def create(
|
|
|
|
|
messages=messages,
|
|
|
|
|
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
|
|
|
|
|
tool_choice=function_call,
|
|
|
|
|
user=str(agent_state.user_id),
|
|
|
|
|
user=str(user_id),
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
data = dict(
|
|
|
|
|
@@ -226,7 +228,7 @@ def create(
|
|
|
|
|
messages=messages,
|
|
|
|
|
functions=functions,
|
|
|
|
|
function_call=function_call,
|
|
|
|
|
user=str(agent_state.user_id),
|
|
|
|
|
user=str(user_id),
|
|
|
|
|
)
|
|
|
|
|
return azure_openai_chat_completions_request(
|
|
|
|
|
resource_name=credentials.azure_endpoint,
|
|
|
|
|
@@ -236,7 +238,7 @@ def create(
|
|
|
|
|
data=data,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
elif agent_state.llm_config.model_endpoint_type == "google_ai":
|
|
|
|
|
elif llm_config.model_endpoint_type == "google_ai":
|
|
|
|
|
if not use_tool_naming:
|
|
|
|
|
raise NotImplementedError("Only tool calling supported on Google AI API requests")
|
|
|
|
|
|
|
|
|
|
@@ -254,7 +256,7 @@ def create(
|
|
|
|
|
return google_ai_chat_completions_request(
|
|
|
|
|
inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg,
|
|
|
|
|
service_endpoint=credentials.google_ai_service_endpoint,
|
|
|
|
|
model=agent_state.llm_config.model,
|
|
|
|
|
model=llm_config.model,
|
|
|
|
|
api_key=credentials.google_ai_key,
|
|
|
|
|
# see structure of payload here: https://ai.google.dev/docs/function_calling
|
|
|
|
|
data=dict(
|
|
|
|
|
@@ -263,7 +265,7 @@ def create(
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
elif agent_state.llm_config.model_endpoint_type == "anthropic":
|
|
|
|
|
elif llm_config.model_endpoint_type == "anthropic":
|
|
|
|
|
if not use_tool_naming:
|
|
|
|
|
raise NotImplementedError("Only tool calling supported on Anthropic API requests")
|
|
|
|
|
|
|
|
|
|
@@ -274,20 +276,20 @@ def create(
|
|
|
|
|
tools = None
|
|
|
|
|
|
|
|
|
|
return anthropic_chat_completions_request(
|
|
|
|
|
url=agent_state.llm_config.model_endpoint,
|
|
|
|
|
url=llm_config.model_endpoint,
|
|
|
|
|
api_key=credentials.anthropic_key,
|
|
|
|
|
data=ChatCompletionRequest(
|
|
|
|
|
model=agent_state.llm_config.model,
|
|
|
|
|
model=llm_config.model,
|
|
|
|
|
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
|
|
|
|
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
|
|
|
|
|
# tool_choice=function_call,
|
|
|
|
|
# user=str(agent_state.user_id),
|
|
|
|
|
# user=str(user_id),
|
|
|
|
|
# NOTE: max_tokens is required for Anthropic API
|
|
|
|
|
max_tokens=1024, # TODO make dynamic
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
elif agent_state.llm_config.model_endpoint_type == "cohere":
|
|
|
|
|
elif llm_config.model_endpoint_type == "cohere":
|
|
|
|
|
if not use_tool_naming:
|
|
|
|
|
raise NotImplementedError("Only tool calling supported on Cohere API requests")
|
|
|
|
|
|
|
|
|
|
@@ -298,7 +300,7 @@ def create(
|
|
|
|
|
tools = None
|
|
|
|
|
|
|
|
|
|
return cohere_chat_completions_request(
|
|
|
|
|
# url=agent_state.llm_config.model_endpoint,
|
|
|
|
|
# url=llm_config.model_endpoint,
|
|
|
|
|
url="https://api.cohere.ai/v1", # TODO
|
|
|
|
|
api_key=os.getenv("COHERE_API_KEY"), # TODO remove
|
|
|
|
|
chat_completion_request=ChatCompletionRequest(
|
|
|
|
|
@@ -306,7 +308,7 @@ def create(
|
|
|
|
|
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
|
|
|
|
|
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
|
|
|
|
|
tool_choice=function_call,
|
|
|
|
|
# user=str(agent_state.user_id),
|
|
|
|
|
# user=str(user_id),
|
|
|
|
|
# NOTE: max_tokens is required for Anthropic API
|
|
|
|
|
# max_tokens=1024, # TODO make dynamic
|
|
|
|
|
),
|
|
|
|
|
@@ -315,16 +317,16 @@ def create(
|
|
|
|
|
# local model
|
|
|
|
|
else:
|
|
|
|
|
return get_chat_completion(
|
|
|
|
|
model=agent_state.llm_config.model,
|
|
|
|
|
model=llm_config.model,
|
|
|
|
|
messages=messages,
|
|
|
|
|
functions=functions,
|
|
|
|
|
functions_python=functions_python,
|
|
|
|
|
function_call=function_call,
|
|
|
|
|
context_window=agent_state.llm_config.context_window,
|
|
|
|
|
endpoint=agent_state.llm_config.model_endpoint,
|
|
|
|
|
endpoint_type=agent_state.llm_config.model_endpoint_type,
|
|
|
|
|
wrapper=agent_state.llm_config.model_wrapper,
|
|
|
|
|
user=str(agent_state.user_id),
|
|
|
|
|
context_window=llm_config.context_window,
|
|
|
|
|
endpoint=llm_config.model_endpoint,
|
|
|
|
|
endpoint_type=llm_config.model_endpoint_type,
|
|
|
|
|
wrapper=llm_config.model_wrapper,
|
|
|
|
|
user=str(user_id),
|
|
|
|
|
# hint
|
|
|
|
|
first_message=first_message,
|
|
|
|
|
# auth-related
|
|
|
|
|
|