fix: refactor create(..) call to LLMs to not require AgentState (#1307)

Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
Sarah Wooders
2024-04-28 15:21:20 -07:00
committed by GitHub
parent ac06ef9e22
commit 877dd89f3c
4 changed files with 40 additions and 34 deletions

View File

@@ -424,7 +424,9 @@ class Agent(object):
"""Get response from LLM API"""
try:
response = create(
agent_state=self.agent_state,
# agent_state=self.agent_state,
llm_config=self.agent_state.llm_config,
user_id=self.agent_state.user_id,
messages=message_sequence,
functions=self.functions,
functions_python=self.functions_python,

View File

@@ -31,6 +31,7 @@ def message_chatgpt(self, message: str):
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE),
Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=str(message)),
]
# TODO: this will error without an LLMConfig
response = create(
model=MESSAGE_CHATGPT_FUNCTION_MODEL,
messages=message_sequence,

View File

@@ -1,13 +1,14 @@
import os
import random
import time
import uuid
from typing import List, Optional, Union
import requests
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import AgentState, Message
from memgpt.data_types import Message
from memgpt.llm_api.anthropic import anthropic_chat_completions_request
from memgpt.llm_api.azure_openai import (
MODEL_TO_AZURE_ENGINE,
@@ -29,6 +30,7 @@ from memgpt.models.chat_completion_request import (
cast_message_to_subtype,
)
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.pydantic_models import LLMConfigModel
from memgpt.streaming_interface import (
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
@@ -135,8 +137,10 @@ def retry_with_exponential_backoff(
@retry_with_exponential_backoff
def create(
agent_state: AgentState,
# agent_state: AgentState,
llm_config: LLMConfigModel,
messages: List[Message],
user_id: uuid.UUID = None, # optional UUID to associate the request with
functions: list = None,
functions_python: list = None,
function_call: str = "auto",
@@ -152,7 +156,7 @@ def create(
"""Return response to chat completion with backoff"""
from memgpt.utils import printd
printd(f"Using model {agent_state.llm_config.model_endpoint_type}, endpoint: {agent_state.llm_config.model_endpoint}")
printd(f"Using model {llm_config.model_endpoint_type}, endpoint: {llm_config.model_endpoint}")
# TODO eventually refactor so that credentials are passed through
credentials = MemGPTCredentials.load()
@@ -162,26 +166,26 @@ def create(
function_call = None
# openai
if agent_state.llm_config.model_endpoint_type == "openai":
if llm_config.model_endpoint_type == "openai":
# TODO do the same for Azure?
if credentials.openai_key is None and agent_state.llm_config.model_endpoint == "https://api.openai.com/v1":
if credentials.openai_key is None and llm_config.model_endpoint == "https://api.openai.com/v1":
# this is only a problem if we are *not* using an OpenAI proxy
raise ValueError(f"OpenAI key is missing from MemGPT config file")
if use_tool_naming:
data = ChatCompletionRequest(
model=agent_state.llm_config.model,
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
tool_choice=function_call,
user=str(agent_state.user_id),
user=str(user_id),
)
else:
data = ChatCompletionRequest(
model=agent_state.llm_config.model,
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
functions=functions,
function_call=function_call,
user=str(agent_state.user_id),
user=str(user_id),
)
if stream:
@@ -190,7 +194,7 @@ def create(
stream_inferface, AgentRefreshStreamingInterface
), type(stream_inferface)
return openai_chat_completions_process_stream(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
chat_completion_request=data,
stream_inferface=stream_inferface,
@@ -198,17 +202,15 @@ def create(
else:
data.stream = False
return openai_chat_completions_request(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
chat_completion_request=data,
)
# azure
elif agent_state.llm_config.model_endpoint_type == "azure":
elif llm_config.model_endpoint_type == "azure":
azure_deployment = (
credentials.azure_deployment
if credentials.azure_deployment is not None
else MODEL_TO_AZURE_ENGINE[agent_state.llm_config.model]
credentials.azure_deployment if credentials.azure_deployment is not None else MODEL_TO_AZURE_ENGINE[llm_config.model]
)
if use_tool_naming:
data = dict(
@@ -217,7 +219,7 @@ def create(
messages=messages,
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
tool_choice=function_call,
user=str(agent_state.user_id),
user=str(user_id),
)
else:
data = dict(
@@ -226,7 +228,7 @@ def create(
messages=messages,
functions=functions,
function_call=function_call,
user=str(agent_state.user_id),
user=str(user_id),
)
return azure_openai_chat_completions_request(
resource_name=credentials.azure_endpoint,
@@ -236,7 +238,7 @@ def create(
data=data,
)
elif agent_state.llm_config.model_endpoint_type == "google_ai":
elif llm_config.model_endpoint_type == "google_ai":
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Google AI API requests")
@@ -254,7 +256,7 @@ def create(
return google_ai_chat_completions_request(
inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg,
service_endpoint=credentials.google_ai_service_endpoint,
model=agent_state.llm_config.model,
model=llm_config.model,
api_key=credentials.google_ai_key,
# see structure of payload here: https://ai.google.dev/docs/function_calling
data=dict(
@@ -263,7 +265,7 @@ def create(
),
)
elif agent_state.llm_config.model_endpoint_type == "anthropic":
elif llm_config.model_endpoint_type == "anthropic":
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Anthropic API requests")
@@ -274,20 +276,20 @@ def create(
tools = None
return anthropic_chat_completions_request(
url=agent_state.llm_config.model_endpoint,
url=llm_config.model_endpoint,
api_key=credentials.anthropic_key,
data=ChatCompletionRequest(
model=agent_state.llm_config.model,
model=llm_config.model,
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
# tool_choice=function_call,
# user=str(agent_state.user_id),
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
max_tokens=1024, # TODO make dynamic
),
)
elif agent_state.llm_config.model_endpoint_type == "cohere":
elif llm_config.model_endpoint_type == "cohere":
if not use_tool_naming:
raise NotImplementedError("Only tool calling supported on Cohere API requests")
@@ -298,7 +300,7 @@ def create(
tools = None
return cohere_chat_completions_request(
# url=agent_state.llm_config.model_endpoint,
# url=llm_config.model_endpoint,
url="https://api.cohere.ai/v1", # TODO
api_key=os.getenv("COHERE_API_KEY"), # TODO remove
chat_completion_request=ChatCompletionRequest(
@@ -306,7 +308,7 @@ def create(
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
tools=[{"type": "function", "function": f} for f in functions] if functions else None,
tool_choice=function_call,
# user=str(agent_state.user_id),
# user=str(user_id),
# NOTE: max_tokens is required for Anthropic API
# max_tokens=1024, # TODO make dynamic
),
@@ -315,16 +317,16 @@ def create(
# local model
else:
return get_chat_completion(
model=agent_state.llm_config.model,
model=llm_config.model,
messages=messages,
functions=functions,
functions_python=functions_python,
function_call=function_call,
context_window=agent_state.llm_config.context_window,
endpoint=agent_state.llm_config.model_endpoint,
endpoint_type=agent_state.llm_config.model_endpoint_type,
wrapper=agent_state.llm_config.model_wrapper,
user=str(agent_state.user_id),
context_window=llm_config.context_window,
endpoint=llm_config.model_endpoint,
endpoint_type=llm_config.model_endpoint_type,
wrapper=llm_config.model_wrapper,
user=str(user_id),
# hint
first_message=first_message,
# auth-related

View File

@@ -142,7 +142,8 @@ def summarize_messages(
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=summary_input))
response = create(
agent_state=agent_state,
llm_config=agent_state.llm_config,
user_id=agent_state.user_id,
messages=message_sequence,
)