From 877dd89f3c5fcd832bc8f418d38680eac118805d Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Sun, 28 Apr 2024 15:21:20 -0700 Subject: [PATCH] fix: refactor `create(..)` call to LLMs to not require `AgentState` (#1307) Co-authored-by: cpacker --- memgpt/agent.py | 4 +- memgpt/functions/function_sets/extras.py | 1 + memgpt/llm_api/llm_api_tools.py | 66 ++++++++++++------------ memgpt/memory.py | 3 +- 4 files changed, 40 insertions(+), 34 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index b02b6cd4..22aeeb40 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -424,7 +424,9 @@ class Agent(object): """Get response from LLM API""" try: response = create( - agent_state=self.agent_state, + # agent_state=self.agent_state, + llm_config=self.agent_state.llm_config, + user_id=self.agent_state.user_id, messages=message_sequence, functions=self.functions, functions_python=self.functions_python, diff --git a/memgpt/functions/function_sets/extras.py b/memgpt/functions/function_sets/extras.py index 9eb90988..025c3e6d 100644 --- a/memgpt/functions/function_sets/extras.py +++ b/memgpt/functions/function_sets/extras.py @@ -31,6 +31,7 @@ def message_chatgpt(self, message: str): Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE), Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=str(message)), ] + # TODO: this will error without an LLMConfig response = create( model=MESSAGE_CHATGPT_FUNCTION_MODEL, messages=message_sequence, diff --git a/memgpt/llm_api/llm_api_tools.py b/memgpt/llm_api/llm_api_tools.py index 220bebb1..59abc09e 100644 --- a/memgpt/llm_api/llm_api_tools.py +++ b/memgpt/llm_api/llm_api_tools.py @@ -1,13 +1,14 @@ import os import random import time +import uuid from typing import List, Optional, Union import requests from memgpt.constants import CLI_WARNING_PREFIX from memgpt.credentials import MemGPTCredentials -from memgpt.data_types import AgentState, Message 
+from memgpt.data_types import Message from memgpt.llm_api.anthropic import anthropic_chat_completions_request from memgpt.llm_api.azure_openai import ( MODEL_TO_AZURE_ENGINE, @@ -29,6 +30,7 @@ from memgpt.models.chat_completion_request import ( cast_message_to_subtype, ) from memgpt.models.chat_completion_response import ChatCompletionResponse +from memgpt.models.pydantic_models import LLMConfigModel from memgpt.streaming_interface import ( AgentChunkStreamingInterface, AgentRefreshStreamingInterface, @@ -135,8 +137,10 @@ def retry_with_exponential_backoff( @retry_with_exponential_backoff def create( - agent_state: AgentState, + # agent_state: AgentState, + llm_config: LLMConfigModel, messages: List[Message], + user_id: uuid.UUID = None, # optional UUID to associate request with functions: list = None, functions_python: list = None, function_call: str = "auto", @@ -152,7 +156,7 @@ def create( """Return response to chat completion with backoff""" from memgpt.utils import printd - printd(f"Using model {agent_state.llm_config.model_endpoint_type}, endpoint: {agent_state.llm_config.model_endpoint}") + printd(f"Using model {llm_config.model_endpoint_type}, endpoint: {llm_config.model_endpoint}") # TODO eventually refactor so that credentials are passed through credentials = MemGPTCredentials.load() @@ -162,26 +166,26 @@ def create( function_call = None # openai - if agent_state.llm_config.model_endpoint_type == "openai": + if llm_config.model_endpoint_type == "openai": # TODO do the same for Azure?
- if credentials.openai_key is None and agent_state.llm_config.model_endpoint == "https://api.openai.com/v1": + if credentials.openai_key is None and llm_config.model_endpoint == "https://api.openai.com/v1": # only is a problem if we are *not* using an openai proxy raise ValueError(f"OpenAI key is missing from MemGPT config file") if use_tool_naming: data = ChatCompletionRequest( - model=agent_state.llm_config.model, + model=llm_config.model, messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], tools=[{"type": "function", "function": f} for f in functions] if functions else None, tool_choice=function_call, - user=str(agent_state.user_id), + user=str(user_id), ) else: data = ChatCompletionRequest( - model=agent_state.llm_config.model, + model=llm_config.model, messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], functions=functions, function_call=function_call, - user=str(agent_state.user_id), + user=str(user_id), ) if stream: @@ -190,7 +194,7 @@ def create( stream_inferface, AgentRefreshStreamingInterface ), type(stream_inferface) return openai_chat_completions_process_stream( - url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions + url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions api_key=credentials.openai_key, chat_completion_request=data, stream_inferface=stream_inferface, @@ -198,17 +202,15 @@ def create( else: data.stream = False return openai_chat_completions_request( - url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions + url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions api_key=credentials.openai_key, chat_completion_request=data, ) # azure - elif agent_state.llm_config.model_endpoint_type == "azure": + elif llm_config.model_endpoint_type == "azure": azure_deployment = ( - 
credentials.azure_deployment - if credentials.azure_deployment is not None - else MODEL_TO_AZURE_ENGINE[agent_state.llm_config.model] + credentials.azure_deployment if credentials.azure_deployment is not None else MODEL_TO_AZURE_ENGINE[llm_config.model] ) if use_tool_naming: data = dict( @@ -217,7 +219,7 @@ def create( messages=messages, tools=[{"type": "function", "function": f} for f in functions] if functions else None, tool_choice=function_call, - user=str(agent_state.user_id), + user=str(user_id), ) else: data = dict( @@ -226,7 +228,7 @@ def create( messages=messages, functions=functions, function_call=function_call, - user=str(agent_state.user_id), + user=str(user_id), ) return azure_openai_chat_completions_request( resource_name=credentials.azure_endpoint, @@ -236,7 +238,7 @@ def create( data=data, ) - elif agent_state.llm_config.model_endpoint_type == "google_ai": + elif llm_config.model_endpoint_type == "google_ai": if not use_tool_naming: raise NotImplementedError("Only tool calling supported on Google AI API requests") @@ -254,7 +256,7 @@ def create( return google_ai_chat_completions_request( inner_thoughts_in_kwargs=google_ai_inner_thoughts_in_kwarg, service_endpoint=credentials.google_ai_service_endpoint, - model=agent_state.llm_config.model, + model=llm_config.model, api_key=credentials.google_ai_key, # see structure of payload here: https://ai.google.dev/docs/function_calling data=dict( @@ -263,7 +265,7 @@ def create( ), ) - elif agent_state.llm_config.model_endpoint_type == "anthropic": + elif llm_config.model_endpoint_type == "anthropic": if not use_tool_naming: raise NotImplementedError("Only tool calling supported on Anthropic API requests") @@ -274,20 +276,20 @@ def create( tools = None return anthropic_chat_completions_request( - url=agent_state.llm_config.model_endpoint, + url=llm_config.model_endpoint, api_key=credentials.anthropic_key, data=ChatCompletionRequest( - model=agent_state.llm_config.model, + model=llm_config.model, 
messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], tools=[{"type": "function", "function": f} for f in functions] if functions else None, # tool_choice=function_call, - # user=str(agent_state.user_id), + # user=str(user_id), # NOTE: max_tokens is required for Anthropic API max_tokens=1024, # TODO make dynamic ), ) - elif agent_state.llm_config.model_endpoint_type == "cohere": + elif llm_config.model_endpoint_type == "cohere": if not use_tool_naming: raise NotImplementedError("Only tool calling supported on Cohere API requests") @@ -298,7 +300,7 @@ def create( tools = None return cohere_chat_completions_request( - # url=agent_state.llm_config.model_endpoint, + # url=llm_config.model_endpoint, url="https://api.cohere.ai/v1", # TODO api_key=os.getenv("COHERE_API_KEY"), # TODO remove chat_completion_request=ChatCompletionRequest( @@ -306,7 +308,7 @@ def create( messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], tools=[{"type": "function", "function": f} for f in functions] if functions else None, tool_choice=function_call, - # user=str(agent_state.user_id), + # user=str(user_id), # NOTE: max_tokens is required for Anthropic API # max_tokens=1024, # TODO make dynamic ), @@ -315,16 +317,16 @@ def create( # local model else: return get_chat_completion( - model=agent_state.llm_config.model, + model=llm_config.model, messages=messages, functions=functions, functions_python=functions_python, function_call=function_call, - context_window=agent_state.llm_config.context_window, - endpoint=agent_state.llm_config.model_endpoint, - endpoint_type=agent_state.llm_config.model_endpoint_type, - wrapper=agent_state.llm_config.model_wrapper, - user=str(agent_state.user_id), + context_window=llm_config.context_window, + endpoint=llm_config.model_endpoint, + endpoint_type=llm_config.model_endpoint_type, + wrapper=llm_config.model_wrapper, + user=str(user_id), # hint first_message=first_message, # auth-related diff --git a/memgpt/memory.py 
b/memgpt/memory.py index eb2c03ca..b07b0263 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -142,7 +142,8 @@ def summarize_messages( message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=summary_input)) response = create( - agent_state=agent_state, + llm_config=agent_state.llm_config, + user_id=agent_state.user_id, messages=message_sequence, )