feat: add tools from the Python client (#1463)
This commit is contained in:
120
memgpt/agent.py
120
memgpt/agent.py
@@ -21,14 +21,7 @@ from memgpt.constants import (
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
|
||||
MESSAGE_SUMMARY_WARNING_FRAC,
|
||||
)
|
||||
from memgpt.data_types import (
|
||||
AgentState,
|
||||
EmbeddingConfig,
|
||||
LLMConfig,
|
||||
Message,
|
||||
Passage,
|
||||
Preset,
|
||||
)
|
||||
from memgpt.data_types import AgentState, EmbeddingConfig, Message, Passage
|
||||
from memgpt.interface import AgentInterface
|
||||
from memgpt.llm_api.llm_api_tools import create, is_context_overflow_error
|
||||
from memgpt.memory import ArchivalMemory
|
||||
@@ -36,6 +29,7 @@ from memgpt.memory import CoreMemory as InContextMemory
|
||||
from memgpt.memory import RecallMemory, summarize_messages
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.models import chat_completion_response
|
||||
from memgpt.models.pydantic_models import ToolModel
|
||||
from memgpt.persistence_manager import LocalStateManager
|
||||
from memgpt.system import (
|
||||
get_initial_boot_messages,
|
||||
@@ -45,7 +39,6 @@ from memgpt.system import (
|
||||
)
|
||||
from memgpt.utils import (
|
||||
count_tokens,
|
||||
create_random_username,
|
||||
create_uuid_from_string,
|
||||
get_local_time,
|
||||
get_schema_diff,
|
||||
@@ -200,76 +193,44 @@ class Agent(object):
|
||||
self,
|
||||
interface: AgentInterface,
|
||||
# agents can be created from providing agent_state
|
||||
agent_state: Optional[AgentState] = None,
|
||||
# or from providing a preset (requires preset + extra fields)
|
||||
preset: Optional[Preset] = None,
|
||||
created_by: Optional[uuid.UUID] = None,
|
||||
name: Optional[str] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
agent_state: AgentState,
|
||||
tools: List[ToolModel],
|
||||
# extras
|
||||
messages_total: Optional[int] = None, # TODO remove?
|
||||
first_message_verify_mono: bool = True, # TODO move to config?
|
||||
):
|
||||
# An agent can be created from a Preset object
|
||||
if preset is not None:
|
||||
assert agent_state is None, "Can create an agent from a Preset or AgentState (but both were provided)"
|
||||
assert created_by is not None, "Must provide created_by field when creating an Agent from a Preset"
|
||||
assert llm_config is not None, "Must provide llm_config field when creating an Agent from a Preset"
|
||||
assert embedding_config is not None, "Must provide embedding_config field when creating an Agent from a Preset"
|
||||
|
||||
# if agent_state is also provided, override any preset values
|
||||
init_agent_state = AgentState(
|
||||
name=name if name else create_random_username(),
|
||||
user_id=created_by,
|
||||
persona=preset.persona,
|
||||
human=preset.human,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
preset=preset.name, # TODO link via preset.id instead of name?
|
||||
state={
|
||||
"persona": preset.persona,
|
||||
"human": preset.human,
|
||||
"system": preset.system,
|
||||
"functions": preset.functions_schema,
|
||||
"messages": None,
|
||||
},
|
||||
)
|
||||
|
||||
# An agent can also be created directly from AgentState
|
||||
elif agent_state is not None:
|
||||
assert preset is None, "Can create an agent from a Preset or AgentState (but both were provided)"
|
||||
assert agent_state.state is not None and agent_state.state != {}, "AgentState.state cannot be empty"
|
||||
|
||||
# Assume the agent_state passed in is formatted correctly
|
||||
init_agent_state = agent_state
|
||||
|
||||
else:
|
||||
raise ValueError("Both Preset and AgentState were null (must provide one or the other)")
|
||||
# tools
|
||||
for tool in tools:
|
||||
assert tool.name in agent_state.tools, f"Tool {tool} not found in agent_state.tools"
|
||||
for tool_name in agent_state.tools:
|
||||
assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
|
||||
# Store the functions schemas (this is passed as an argument to ChatCompletion)
|
||||
self.functions = []
|
||||
self.functions_python = {}
|
||||
env = {}
|
||||
env.update(globals())
|
||||
for tool in tools:
|
||||
# WARNING: name may not be consistent?
|
||||
exec(tool.module, env)
|
||||
self.functions_python[tool.name] = env[tool.name]
|
||||
self.functions.append(tool.json_schema)
|
||||
assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
|
||||
|
||||
# Hold a copy of the state that was used to init the agent
|
||||
self.agent_state = init_agent_state
|
||||
self.agent_state = agent_state
|
||||
|
||||
# gpt-4, gpt-3.5-turbo, ...
|
||||
self.model = self.agent_state.llm_config.model
|
||||
|
||||
# Store the system instructions (used to rebuild memory)
|
||||
if "system" not in self.agent_state.state:
|
||||
raise ValueError("'system' not found in provided AgentState")
|
||||
self.system = self.agent_state.state["system"]
|
||||
|
||||
if "functions" not in self.agent_state.state:
|
||||
raise ValueError(f"'functions' not found in provided AgentState")
|
||||
# Store the functions schemas (this is passed as an argument to ChatCompletion)
|
||||
self.functions = self.agent_state.state["functions"] # these are the schema
|
||||
# Link the actual python functions corresponding to the schemas
|
||||
self.functions_python = {k: v["python_function"] for k, v in link_functions(function_schemas=self.functions).items()}
|
||||
assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
|
||||
self.system = self.agent_state.system
|
||||
|
||||
# Initialize the memory object
|
||||
if "persona" not in self.agent_state.state:
|
||||
# TODO: support more general memory types
|
||||
if "persona" not in self.agent_state.state: # TODO: remove
|
||||
raise ValueError(f"'persona' not found in provided AgentState")
|
||||
if "human" not in self.agent_state.state:
|
||||
if "human" not in self.agent_state.state: # TODO: remove
|
||||
raise ValueError(f"'human' not found in provided AgentState")
|
||||
self.memory = initialize_memory(ai_notes=self.agent_state.state["persona"], human_notes=self.agent_state.state["human"])
|
||||
|
||||
@@ -283,7 +244,6 @@ class Agent(object):
|
||||
self.interface = interface
|
||||
|
||||
# Create the persistence manager object based on the AgentState info
|
||||
# TODO
|
||||
self.persistence_manager = LocalStateManager(agent_state=self.agent_state)
|
||||
|
||||
# State needed for heartbeat pausing
|
||||
@@ -1053,25 +1013,41 @@ class Agent(object):
|
||||
# self.ms.update_agent(agent=new_agent_state)
|
||||
|
||||
def update_state(self) -> AgentState:
|
||||
updated_state = {
|
||||
# updated_state = {
|
||||
# "persona": self.memory.persona,
|
||||
# "human": self.memory.human,
|
||||
# "system": self.system,
|
||||
# "functions": self.functions,
|
||||
# "messages": [str(msg.id) for msg in self._messages],
|
||||
# }
|
||||
memory = {
|
||||
"system": self.system,
|
||||
"persona": self.memory.persona,
|
||||
"human": self.memory.human,
|
||||
"system": self.system,
|
||||
"functions": self.functions,
|
||||
"messages": [str(msg.id) for msg in self._messages],
|
||||
"messages": [str(msg.id) for msg in self._messages], # TODO: move out into AgentState.message_ids
|
||||
}
|
||||
|
||||
# TODO: add this field
|
||||
metadata = { # TODO
|
||||
"human_name": self.agent_state.persona,
|
||||
"persona_name": self.agent_state.human,
|
||||
}
|
||||
|
||||
self.agent_state = AgentState(
|
||||
name=self.agent_state.name,
|
||||
user_id=self.agent_state.user_id,
|
||||
persona=self.agent_state.persona,
|
||||
human=self.agent_state.human,
|
||||
tools=self.agent_state.tools,
|
||||
system=self.system,
|
||||
persona=self.agent_state.persona, # TODO: remove
|
||||
human=self.agent_state.human, # TODO: remove
|
||||
## "model_state"
|
||||
llm_config=self.agent_state.llm_config,
|
||||
embedding_config=self.agent_state.embedding_config,
|
||||
preset=self.agent_state.preset,
|
||||
id=self.agent_state.id,
|
||||
created_at=self.agent_state.created_at,
|
||||
state=updated_state,
|
||||
## "agent_state"
|
||||
state=memory,
|
||||
)
|
||||
return self.agent_state
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from memgpt.cli.cli_config import configure
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import CLI_WARNING_PREFIX, MEMGPT_DIR
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.data_types import EmbeddingConfig, LLMConfig, User
|
||||
from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, User
|
||||
from memgpt.log import get_logger
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.migrate import migrate_all_agents, migrate_all_sources
|
||||
@@ -582,9 +582,10 @@ def run(
|
||||
|
||||
# Update the agent with any overrides
|
||||
ms.update_agent(agent_state)
|
||||
tools = [ms.get_tool(tool_name) for tool_name in agent_state.tools]
|
||||
|
||||
# create agent
|
||||
memgpt_agent = Agent(agent_state=agent_state, interface=interface())
|
||||
memgpt_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools)
|
||||
|
||||
else: # create new agent
|
||||
# create new agent config: override defaults with args if provided
|
||||
@@ -650,13 +651,25 @@ def run(
|
||||
typer.secho(f"-> 🤖 Using persona profile: '{preset_obj.persona_name}'", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> 🧑 Using human profile: '{preset_obj.human_name}'", fg=typer.colors.WHITE)
|
||||
|
||||
memgpt_agent = Agent(
|
||||
interface=interface(),
|
||||
agent_state = AgentState(
|
||||
name=agent_name,
|
||||
created_by=user.id,
|
||||
preset=preset_obj,
|
||||
user_id=user.id,
|
||||
tools=list([schema["name"] for schema in preset_obj.functions_schema]),
|
||||
system=preset_obj.system,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
human=preset_obj.human,
|
||||
persona=preset_obj.persona,
|
||||
preset=preset_obj.name,
|
||||
state={"messages": None, "persona": preset_obj.persona, "human": preset_obj.human},
|
||||
)
|
||||
print("tools", agent_state.tools)
|
||||
tools = [ms.get_tool(tool_name) for tool_name in agent_state.tools]
|
||||
|
||||
memgpt_agent = Agent(
|
||||
interface=interface(),
|
||||
agent_state=agent_state,
|
||||
tools=tools,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@ from memgpt.server.rest_api.admin.users import (
|
||||
GetAllUsersResponse,
|
||||
GetAPIKeysResponse,
|
||||
)
|
||||
from memgpt.server.rest_api.tools.index import ListToolsResponse
|
||||
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
|
||||
|
||||
|
||||
class Admin:
|
||||
@@ -84,11 +84,40 @@ class Admin:
|
||||
self.delete_key(key)
|
||||
self.delete_user(user["user_id"])
|
||||
|
||||
# tools (currently only available for admin)
|
||||
def create_tool(self, name: str, file_path: str, source_type: Optional[str] = "python", tags: Optional[List[str]] = None) -> ToolModel:
|
||||
"""Add a tool implemented in a file path"""
|
||||
source_code = open(file_path, "r", encoding="utf-8").read()
|
||||
data = {"name": name, "source_code": source_code, "source_type": source_type, "tags": tags}
|
||||
def create_tool(
|
||||
self,
|
||||
func,
|
||||
name: Optional[str] = None,
|
||||
update: Optional[bool] = True, # TODO: actually use this
|
||||
tags: Optional[List[str]] = None,
|
||||
):
|
||||
"""Create a tool
|
||||
|
||||
Args:
|
||||
func (callable): The function to create a tool for.
|
||||
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
|
||||
update (bool, optional): Update the tool if it already exists. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Tool object
|
||||
"""
|
||||
import inspect
|
||||
|
||||
from memgpt.functions.schema_generator import generate_schema
|
||||
|
||||
# TODO: check if tool already exists
|
||||
# TODO: how to load modules?
|
||||
# parse source code/schema
|
||||
source_code = inspect.getsource(func)
|
||||
json_schema = generate_schema(func, name)
|
||||
source_type = "python"
|
||||
tool_name = json_schema["name"]
|
||||
|
||||
# create data
|
||||
data = {"name": tool_name, "source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
|
||||
CreateToolRequest(**data) # validate data:w
|
||||
|
||||
# make REST request
|
||||
response = requests.post(f"{self.base_url}/admin/tools", json=data, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create tool: {response.text}")
|
||||
@@ -96,7 +125,7 @@ class Admin:
|
||||
|
||||
def list_tools(self) -> ListToolsResponse:
|
||||
response = requests.get(f"{self.base_url}/admin/tools", headers=self.headers)
|
||||
return ListToolsResponse(**response.json())
|
||||
return ListToolsResponse(**response.json()).tools
|
||||
|
||||
def delete_tool(self, name: str):
|
||||
response = requests.delete(f"{self.base_url}/admin/tools/{name}", headers=self.headers)
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
import requests
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from memgpt.constants import BASE_TOOLS, DEFAULT_PRESET
|
||||
from memgpt.data_sources.connectors import DataConnector
|
||||
from memgpt.data_types import (
|
||||
AgentState,
|
||||
@@ -26,6 +26,7 @@ from memgpt.models.pydantic_models import (
|
||||
SourceModel,
|
||||
ToolModel,
|
||||
)
|
||||
from memgpt.presets.presets import load_module_tools
|
||||
|
||||
# import pydantic response objects from memgpt.server.rest_api
|
||||
from memgpt.server.rest_api.agents.command import CommandResponse
|
||||
@@ -52,7 +53,6 @@ from memgpt.server.rest_api.presets.index import (
|
||||
ListPresetsResponse,
|
||||
)
|
||||
from memgpt.server.rest_api.sources.index import ListSourcesResponse
|
||||
from memgpt.server.rest_api.tools.index import CreateToolResponse
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
|
||||
@@ -186,12 +186,6 @@ class AbstractClient(object):
|
||||
"""List all tools."""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_tool(
|
||||
self, name: str, file_path: str, source_type: Optional[str] = "python", tags: Optional[List[str]] = None
|
||||
) -> CreateToolResponse:
|
||||
"""Create a tool."""
|
||||
raise NotImplementedError
|
||||
|
||||
# data sources
|
||||
|
||||
def list_sources(self):
|
||||
@@ -266,9 +260,32 @@ class RESTClient(AbstractClient):
|
||||
human: Optional[str] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
# tools
|
||||
tools: Optional[List[str]] = None,
|
||||
include_base_tools: Optional[bool] = True,
|
||||
) -> AgentState:
|
||||
"""
|
||||
Create an agent
|
||||
|
||||
Args:
|
||||
name (str): Name of the agent
|
||||
tools (List[str]): List of tools (by name) to attach to the agent
|
||||
include_base_tools (bool): Whether to include base tools (default: `True`)
|
||||
|
||||
Returns:
|
||||
agent_state (AgentState): State of the the created agent.
|
||||
|
||||
"""
|
||||
if embedding_config or llm_config:
|
||||
raise ValueError("Cannot override embedding_config or llm_config when creating agent via REST API")
|
||||
|
||||
# construct list of tools
|
||||
tool_names = []
|
||||
if tools:
|
||||
tool_names += tools
|
||||
if include_base_tools:
|
||||
tool_names += BASE_TOOLS
|
||||
|
||||
# TODO: distinguish between name and objects
|
||||
payload = {
|
||||
"config": {
|
||||
@@ -276,6 +293,7 @@ class RESTClient(AbstractClient):
|
||||
"preset": preset,
|
||||
"persona": persona,
|
||||
"human": human,
|
||||
"function_names": tool_names,
|
||||
}
|
||||
}
|
||||
response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers)
|
||||
@@ -310,6 +328,8 @@ class RESTClient(AbstractClient):
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
state=response.agent_state.state,
|
||||
system=response.agent_state.system,
|
||||
tools=response.agent_state.tools,
|
||||
# load datetime from timestampe
|
||||
created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at, tz=datetime.timezone.utc),
|
||||
)
|
||||
@@ -657,6 +677,8 @@ class LocalClient(AbstractClient):
|
||||
if name and self.agent_exists(agent_name=name):
|
||||
raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})")
|
||||
|
||||
# TODO: implement tools support
|
||||
|
||||
self.interface.clear()
|
||||
agent_state = self.server.create_agent(
|
||||
user_id=self.user_id,
|
||||
@@ -664,6 +686,7 @@ class LocalClient(AbstractClient):
|
||||
preset=preset,
|
||||
persona=persona,
|
||||
human=human,
|
||||
tools=[tool.name for tool in load_module_tools()],
|
||||
)
|
||||
return agent_state
|
||||
|
||||
|
||||
@@ -181,6 +181,7 @@ class MemGPTConfig:
|
||||
# create new config
|
||||
anon_clientid = MemGPTConfig.generate_uuid()
|
||||
config = cls(anon_clientid=anon_clientid, config_path=config_path)
|
||||
|
||||
config.create_config_dir() # create dirs
|
||||
|
||||
return config
|
||||
|
||||
@@ -21,6 +21,18 @@ DEFAULT_PERSONA = "sam_pov"
|
||||
DEFAULT_HUMAN = "basic"
|
||||
DEFAULT_PRESET = "memgpt_chat"
|
||||
|
||||
# Tools
|
||||
BASE_TOOLS = [
|
||||
"send_message",
|
||||
"core_memory_replace",
|
||||
"core_memory_append",
|
||||
"pause_heartbeats",
|
||||
"conversation_search",
|
||||
"conversation_search_date",
|
||||
"archival_memory_insert",
|
||||
"archival_memory_search",
|
||||
]
|
||||
|
||||
# LOGGER_LOG_LEVEL is use to convert Text to Logging level value for logging mostly for Cli input to setting level
|
||||
LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET}
|
||||
|
||||
|
||||
@@ -754,11 +754,16 @@ class AgentState:
|
||||
self,
|
||||
name: str,
|
||||
user_id: uuid.UUID,
|
||||
persona: str, # the filename where the persona was originally sourced from
|
||||
human: str, # the filename where the human was originally sourced from
|
||||
# tools
|
||||
tools: List[str], # list of tools by name
|
||||
# system prompt
|
||||
system: str,
|
||||
# config
|
||||
persona: str, # the filename where the persona was originally sourced from # TODO: remove
|
||||
human: str, # the filename where the human was originally sourced from # TODO: remove
|
||||
llm_config: LLMConfig,
|
||||
embedding_config: EmbeddingConfig,
|
||||
preset: str,
|
||||
preset: str, # TODO: remove
|
||||
# (in-context) state contains:
|
||||
# persona: str # the current persona text
|
||||
# human: str # the current human text
|
||||
@@ -768,6 +773,8 @@ class AgentState:
|
||||
id: Optional[uuid.UUID] = None,
|
||||
state: Optional[dict] = None,
|
||||
created_at: Optional[datetime] = None,
|
||||
# messages (TODO: implement this)
|
||||
# _metadata: Optional[dict] = None,
|
||||
):
|
||||
if id is None:
|
||||
self.id = uuid.uuid4()
|
||||
@@ -779,6 +786,7 @@ class AgentState:
|
||||
# TODO(swooders) we need to handle the case where name is None here
|
||||
# in AgentConfig we autogenerate a name, not sure what the correct thing w/ DBs is, what about NounAdjective combos? Like giphy does? BoredGiraffe etc
|
||||
self.name = name
|
||||
assert self.name, f"AgentState name must be a non-empty string"
|
||||
self.user_id = user_id
|
||||
self.preset = preset
|
||||
# The INITIAL values of the persona and human
|
||||
@@ -794,6 +802,15 @@ class AgentState:
|
||||
# state
|
||||
self.state = {} if not state else state
|
||||
|
||||
# tools
|
||||
self.tools = tools
|
||||
|
||||
# system
|
||||
self.system = system
|
||||
|
||||
# metadata
|
||||
# self._metadata = _metadata
|
||||
|
||||
|
||||
class Source:
|
||||
def __init__(
|
||||
|
||||
@@ -28,6 +28,7 @@ def load_function_set(module: ModuleType) -> dict:
|
||||
|
||||
generated_schema = generate_schema(attr)
|
||||
function_dict[attr_name] = {
|
||||
"module": inspect.getsource(module),
|
||||
"python_function": attr,
|
||||
"json_schema": generated_schema,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import inspect
|
||||
import typing
|
||||
from typing import get_args, get_origin
|
||||
from typing import Optional, get_args, get_origin
|
||||
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel
|
||||
@@ -83,7 +83,7 @@ def pydantic_model_to_open_ai(model):
|
||||
}
|
||||
|
||||
|
||||
def generate_schema(function):
|
||||
def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None):
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(function)
|
||||
|
||||
@@ -92,11 +92,13 @@ def generate_schema(function):
|
||||
|
||||
# Prepare the schema dictionary
|
||||
schema = {
|
||||
"name": function.__name__,
|
||||
"description": docstring.short_description,
|
||||
"name": function.__name__ if name is None else name,
|
||||
"description": docstring.short_description if description is None else description,
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
}
|
||||
|
||||
# TODO: ensure that 'agent' keyword is reserved for `Agent` class
|
||||
|
||||
for param in sig.parameters.values():
|
||||
# Exclude 'self' parameter
|
||||
if param.name == "self":
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
""" Metadata store for user/agent/data_source information"""
|
||||
|
||||
import inspect as python_inspect
|
||||
import os
|
||||
import secrets
|
||||
import traceback
|
||||
@@ -35,7 +34,6 @@ from memgpt.data_types import (
|
||||
Token,
|
||||
User,
|
||||
)
|
||||
from memgpt.functions.functions import load_all_function_sets
|
||||
from memgpt.models.pydantic_models import (
|
||||
HumanModel,
|
||||
JobModel,
|
||||
@@ -179,6 +177,7 @@ class AgentModel(Base):
|
||||
name = Column(String, nullable=False)
|
||||
persona = Column(String)
|
||||
human = Column(String)
|
||||
system = Column(String)
|
||||
preset = Column(String)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
@@ -189,6 +188,9 @@ class AgentModel(Base):
|
||||
# state
|
||||
state = Column(JSON)
|
||||
|
||||
# tools
|
||||
tools = Column(JSON)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Agent(id='{self.id}', name='{self.name}')>"
|
||||
|
||||
@@ -204,6 +206,8 @@ class AgentModel(Base):
|
||||
llm_config=self.llm_config,
|
||||
embedding_config=self.embedding_config,
|
||||
state=self.state,
|
||||
tools=self.tools,
|
||||
system=self.system,
|
||||
)
|
||||
|
||||
|
||||
@@ -545,6 +549,13 @@ class MetadataStore:
|
||||
session.commit()
|
||||
session.refresh(persona)
|
||||
|
||||
@enforce_types
|
||||
def update_tool(self, tool: ToolModel):
|
||||
with self.session_maker() as session:
|
||||
session.add(tool)
|
||||
session.commit()
|
||||
session.refresh(tool)
|
||||
|
||||
@enforce_types
|
||||
def delete_agent(self, agent_id: uuid.UUID):
|
||||
with self.session_maker() as session:
|
||||
@@ -595,20 +606,8 @@ class MetadataStore:
|
||||
# def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can creat tools
|
||||
def list_tools(self) -> List[ToolModel]:
|
||||
with self.session_maker() as session:
|
||||
available_functions = load_all_function_sets()
|
||||
results = [
|
||||
ToolModel(
|
||||
name=k,
|
||||
json_schema=v["json_schema"],
|
||||
tags=v["tags"],
|
||||
source_type="python",
|
||||
source_code=python_inspect.getsource(v["python_function"]),
|
||||
)
|
||||
for k, v in available_functions.items()
|
||||
]
|
||||
results = session.query(ToolModel).all()
|
||||
return results
|
||||
# results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
|
||||
# return [r.to_record() for r in results]
|
||||
|
||||
@enforce_types
|
||||
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# tool imports
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -48,11 +49,12 @@ class PresetModel(BaseModel):
|
||||
|
||||
class ToolModel(SQLModel, table=True):
|
||||
# TODO move into database
|
||||
name: str = Field(..., description="The name of the function.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.", primary_key=True)
|
||||
name: str = Field(..., description="The name of the function.", primary_key=True)
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.")
|
||||
tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
source_code: Optional[str] = Field(..., description="The source code of the function.")
|
||||
module: Optional[str] = Field(None, description="The module of the function.")
|
||||
|
||||
json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.")
|
||||
|
||||
@@ -89,7 +91,9 @@ class AgentStateModel(BaseModel):
|
||||
preset: str = Field(..., description="The preset used by the agent.")
|
||||
persona: str = Field(..., description="The persona used by the agent.")
|
||||
human: str = Field(..., description="The human used by the agent.")
|
||||
functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")
|
||||
tools: List[str] = Field(..., description="The tools used by the agent.")
|
||||
system: str = Field(..., description="The system prompt used by the agent.")
|
||||
# functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")
|
||||
|
||||
# llm information
|
||||
llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import uuid
|
||||
from typing import List
|
||||
@@ -23,8 +24,8 @@ available_presets = load_all_presets()
|
||||
preset_options = list(available_presets.keys())
|
||||
|
||||
|
||||
def add_default_tools(user_id: uuid.UUID, ms: MetadataStore):
|
||||
module_name = "base"
|
||||
def load_module_tools(module_name="base"):
|
||||
# return List[ToolModel] from base.py tools
|
||||
full_module_name = f"memgpt.functions.function_sets.{module_name}"
|
||||
try:
|
||||
module = importlib.import_module(full_module_name)
|
||||
@@ -42,8 +43,29 @@ def add_default_tools(user_id: uuid.UUID, ms: MetadataStore):
|
||||
printd(err)
|
||||
|
||||
# create tool in db
|
||||
tools = []
|
||||
for name, schema in functions_to_schema.items():
|
||||
ms.add_tool(ToolModel(name=name, tags=["base"], source_type="python", json_schema=schema["json_schema"]))
|
||||
# print([str(inspect.getsource(line)) for line in schema["imports"]])
|
||||
source_code = inspect.getsource(schema["python_function"])
|
||||
tools.append(
|
||||
ToolModel(
|
||||
name=name,
|
||||
tags=["base"],
|
||||
source_type="python",
|
||||
module=schema["module"],
|
||||
source_code=source_code,
|
||||
json_schema=schema["json_schema"],
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
||||
|
||||
def add_default_tools(user_id: uuid.UUID, ms: MetadataStore):
|
||||
module_name = "base"
|
||||
for tool in load_module_tools(module_name=module_name):
|
||||
existing_tool = ms.get_tool(tool.name)
|
||||
if not existing_tool:
|
||||
ms.add_tool(tool)
|
||||
|
||||
|
||||
def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
|
||||
|
||||
@@ -15,7 +15,7 @@ class ListToolsResponse(BaseModel):
|
||||
|
||||
|
||||
class CreateToolRequest(BaseModel):
|
||||
name: str = Field(..., description="The name of the function.")
|
||||
json_schema: dict = Field(..., description="JSON schema of the tool.")
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
|
||||
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
|
||||
@@ -74,28 +74,13 @@ def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
|
||||
# user_id: uuid.UUID = Depends(get_current_user_with_server), # TODO: add back when user-specific
|
||||
):
|
||||
"""
|
||||
Create a new tool (dummy route)
|
||||
Create a new tool
|
||||
"""
|
||||
from memgpt.functions.functions import load_function_file, write_function
|
||||
|
||||
# check if function already exists
|
||||
if server.ms.get_tool(request.name):
|
||||
raise ValueError(f"Tool with name {request.name} already exists.")
|
||||
|
||||
# write function to ~/.memgt/functions directory
|
||||
file_path = write_function(request.name, request.name, request.source_code)
|
||||
|
||||
# TODO: Use load_function_file to load function schema
|
||||
schema = load_function_file(file_path)
|
||||
assert len(list(schema.keys())) == 1, "Function schema must have exactly one key"
|
||||
json_schema = list(schema.values())[0]["json_schema"]
|
||||
|
||||
print("adding tool", request.name, request.tags, request.source_code)
|
||||
tool = ToolModel(name=request.name, json_schema=json_schema, tags=request.tags, source_code=request.source_code)
|
||||
tool.id
|
||||
server.ms.add_tool(tool)
|
||||
|
||||
# TODO: insert tool information into DB as ToolModel
|
||||
return server.ms.get_tool(request.name)
|
||||
try:
|
||||
return server.create_tool(
|
||||
json_schema=request.json_schema, source_code=request.source_code, source_type=request.source_type, tags=request.tags
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}")
|
||||
|
||||
return router
|
||||
|
||||
@@ -86,7 +86,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
embedding_config=embedding_config,
|
||||
state=agent_state.state,
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
|
||||
tools=agent_state.tools,
|
||||
system=agent_state.system,
|
||||
),
|
||||
last_run_at=None, # TODO
|
||||
sources=attached_sources,
|
||||
@@ -131,7 +132,8 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
embedding_config=embedding_config,
|
||||
state=agent_state.state,
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
|
||||
tools=agent_state.tools,
|
||||
system=agent_state.system,
|
||||
),
|
||||
last_run_at=None, # TODO
|
||||
sources=attached_sources,
|
||||
|
||||
@@ -14,6 +14,7 @@ from memgpt.models.pydantic_models import (
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.settings import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -26,6 +27,7 @@ class ListAgentsResponse(BaseModel):
|
||||
|
||||
|
||||
class CreateAgentRequest(BaseModel):
|
||||
# TODO: modify this (along with front end)
|
||||
config: dict = Field(..., description="The agent configuration object.")
|
||||
|
||||
|
||||
@@ -60,27 +62,42 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
"""
|
||||
interface.clear()
|
||||
|
||||
# Parse request
|
||||
# TODO: don't just use JSON in the future
|
||||
human_name = request.config["human_name"] if "human_name" in request.config else None
|
||||
human = request.config["human"] if "human" in request.config else None
|
||||
persona_name = request.config["persona_name"] if "persona_name" in request.config else None
|
||||
persona = request.config["persona"] if "persona" in request.config else None
|
||||
preset = request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset
|
||||
tool_names = request.config["function_names"]
|
||||
|
||||
print("PRESET", preset)
|
||||
|
||||
try:
|
||||
agent_state = server.create_agent(
|
||||
user_id=user_id,
|
||||
# **request.config
|
||||
# TODO turn into a pydantic model
|
||||
name=request.config["name"],
|
||||
preset=request.config["preset"] if "preset" in request.config else None,
|
||||
persona_name=request.config["persona_name"] if "persona_name" in request.config else None,
|
||||
human_name=request.config["human_name"] if "human_name" in request.config else None,
|
||||
persona=request.config["persona"] if "persona" in request.config else None,
|
||||
human=request.config["human"] if "human" in request.config else None,
|
||||
preset=preset,
|
||||
persona_name=persona_name,
|
||||
human_name=human_name,
|
||||
persona=persona,
|
||||
human=human,
|
||||
# llm_config=LLMConfigModel(
|
||||
# model=request.config['model'],
|
||||
# )
|
||||
function_names=request.config["function_names"].split(",") if "function_names" in request.config else None,
|
||||
# tools
|
||||
tools=tool_names,
|
||||
# function_names=request.config["function_names"].split(",") if "function_names" in request.config else None,
|
||||
)
|
||||
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
|
||||
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
|
||||
|
||||
# TODO when get_preset returns a PresetModel instead of Preset, we can remove this packing/unpacking line
|
||||
# TODO: remove
|
||||
preset = server.ms.get_preset(name=agent_state.preset, user_id=user_id)
|
||||
print("SYSTEM", agent_state.system)
|
||||
|
||||
return CreateAgentResponse(
|
||||
agent_state=AgentStateModel(
|
||||
@@ -94,7 +111,8 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
embedding_config=embedding_config,
|
||||
state=agent_state.state,
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
|
||||
tools=tool_names,
|
||||
system=agent_state.system,
|
||||
),
|
||||
preset=PresetModel(
|
||||
name=preset.name,
|
||||
|
||||
@@ -46,6 +46,7 @@ from memgpt.models.pydantic_models import (
|
||||
SourceModel,
|
||||
ToolModel,
|
||||
)
|
||||
from memgpt.utils import create_random_username
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -332,7 +333,8 @@ class SyncServer(LockingServer):
|
||||
|
||||
# Instantiate an agent object using the state retrieved
|
||||
logger.info(f"Creating an agent object")
|
||||
memgpt_agent = Agent(agent_state=agent_state, interface=interface)
|
||||
tool_objs = [self.ms.get_tool(name) for name in agent_state.tools] # get tool objects
|
||||
memgpt_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
|
||||
|
||||
# Add the agent to the in-memory store and return its reference
|
||||
logger.info(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}")
|
||||
@@ -645,17 +647,22 @@ class SyncServer(LockingServer):
|
||||
def create_agent(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
tools: List[str], # list of tool names (handles) to include
|
||||
# system: str, # system prompt
|
||||
metadata: Optional[dict] = {}, # includes human/persona names
|
||||
name: Optional[str] = None,
|
||||
preset: Optional[str] = None,
|
||||
persona: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
|
||||
human: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
|
||||
persona_name: Optional[str] = None,
|
||||
human_name: Optional[str] = None,
|
||||
preset: Optional[str] = None, # TODO: remove eventually
|
||||
# model config
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
# interface
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
# persistence_manager: Union[PersistenceManager, None] = None,
|
||||
function_names: Optional[List[str]] = None, # TODO remove
|
||||
# TODO: refactor this to be a more general memory configuration
|
||||
system: Optional[str] = None, # prompt value
|
||||
persona: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
|
||||
human: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
|
||||
persona_name: Optional[str] = None, # TODO: remove
|
||||
human_name: Optional[str] = None, # TODO: remove
|
||||
) -> AgentState:
|
||||
"""Create a new agent using a config"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
@@ -668,6 +675,9 @@ class SyncServer(LockingServer):
|
||||
# if persistence_manager is None:
|
||||
# persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config)
|
||||
|
||||
if name is None:
|
||||
name = create_random_username()
|
||||
|
||||
logger.debug(f"Attempting to find user: {user_id}")
|
||||
user = self.ms.get_user(user_id=user_id)
|
||||
if not user:
|
||||
@@ -684,6 +694,14 @@ class SyncServer(LockingServer):
|
||||
assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist"
|
||||
logger.debug(f"Attempting to create agent from preset:\n{preset_obj}")
|
||||
|
||||
# system prompt
|
||||
if system is None:
|
||||
system = preset_obj.system
|
||||
else:
|
||||
preset_obj.system = system
|
||||
preset_override = True
|
||||
print("system", preset_obj.system, system)
|
||||
|
||||
# Overwrite fields in the preset if they were specified
|
||||
if human is not None and human != preset_obj.human:
|
||||
preset_override = True
|
||||
@@ -718,15 +736,8 @@ class SyncServer(LockingServer):
|
||||
llm_config = llm_config if llm_config else self.server_llm_config
|
||||
embedding_config = embedding_config if embedding_config else self.server_embedding_config
|
||||
|
||||
# TODO remove (https://github.com/cpacker/MemGPT/issues/1138)
|
||||
if function_names is not None:
|
||||
preset_override = True
|
||||
# available_tools = self.ms.list_tools(user_id=user_id) # TODO: add back when user-specific
|
||||
available_tools = self.ms.list_tools()
|
||||
available_tools_names = [t.name for t in available_tools]
|
||||
assert all([f_name in available_tools_names for f_name in function_names])
|
||||
preset_obj.functions_schema = [t.json_schema for t in available_tools if t.name in function_names]
|
||||
print("overriding preset_obj tools with:", preset_obj.functions_schema)
|
||||
# get tools
|
||||
tool_objs = [self.ms.get_tool(name) for name in tools]
|
||||
|
||||
# If the user overrode any parts of the preset, we need to create a new preset to refer back to
|
||||
if preset_override:
|
||||
@@ -735,13 +746,24 @@ class SyncServer(LockingServer):
|
||||
# Then write out to the database for storage
|
||||
self.ms.create_preset(preset=preset_obj)
|
||||
|
||||
agent = Agent(
|
||||
interface=interface,
|
||||
preset=preset_obj,
|
||||
agent_state = AgentState(
|
||||
name=name,
|
||||
created_by=user.id,
|
||||
user_id=user_id,
|
||||
persona=preset_obj.persona_name, # TODO: remove
|
||||
human=preset_obj.human_name, # TODO: remove
|
||||
tools=tools, # name=id for tools
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
system=system,
|
||||
preset=preset, # TODO: remove
|
||||
state={"persona": preset_obj.persona, "human": preset_obj.human, "system": system, "messages": None},
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
interface=interface,
|
||||
agent_state=agent_state,
|
||||
tools=tool_objs,
|
||||
# embedding_config=embedding_config,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
|
||||
)
|
||||
@@ -852,11 +874,11 @@ class SyncServer(LockingServer):
|
||||
|
||||
# TODO remove this eventually when return type get pydanticfied
|
||||
# this is to add persona_name and human_name so that the columns in UI can populate
|
||||
preset = self.ms.get_preset(name=agent_state.preset, user_id=user_id)
|
||||
# TODO hack for frontend, remove
|
||||
# (top level .persona is persona_name, and nested memory.persona is the state)
|
||||
return_dict["persona"] = preset.persona_name
|
||||
return_dict["human"] = preset.human_name
|
||||
# TODO: eventually modify this to be contained in the metadata
|
||||
return_dict["persona"] = agent_state.human
|
||||
return_dict["human"] = agent_state.persona
|
||||
|
||||
# Add information about tools
|
||||
# TODO memgpt_agent should really have a field of List[ToolModel]
|
||||
@@ -1437,8 +1459,36 @@ class SyncServer(LockingServer):
|
||||
|
||||
return sources_with_metadata
|
||||
|
||||
def create_tool(self, name: str, user_id: uuid.UUID) -> ToolModel: # TODO: add other fields
|
||||
"""Create a new tool"""
|
||||
def create_tool(
|
||||
self, json_schema: dict, source_code: str, source_type: str, tags: Optional[List[str]] = None, exists_ok: Optional[bool] = True
|
||||
) -> ToolModel: # TODO: add other fields
|
||||
"""Create a new tool
|
||||
|
||||
def delete_tool(self, tool_id: uuid.UUID, user_id: uuid.UUID):
|
||||
Args:
|
||||
TODO
|
||||
|
||||
Returns:
|
||||
tool (ToolModel): Tool object
|
||||
"""
|
||||
name = json_schema["name"]
|
||||
tool = self.ms.get_tool(name)
|
||||
if tool: # check if function already exists
|
||||
if exists_ok:
|
||||
# update existing tool
|
||||
tool.json_schema = json_schema
|
||||
tool.tags = tags
|
||||
tool.source_code = source_code
|
||||
tool.source_type = source_type
|
||||
self.ms.update_tool(tool)
|
||||
else:
|
||||
raise ValueError(f"Tool with name {name} already exists.")
|
||||
else:
|
||||
# create new tool
|
||||
tool = ToolModel(name=name, json_schema=json_schema, tags=tags, source_code=source_code, source_type=source_type)
|
||||
self.ms.add_tool(tool)
|
||||
|
||||
return self.ms.get_tool(name)
|
||||
|
||||
def delete_tool(self, name: str):
|
||||
"""Delete a tool"""
|
||||
self.ms.delete_tool(name)
|
||||
|
||||
@@ -19,6 +19,9 @@ class Settings(BaseSettings):
|
||||
pg_uri: Optional[str] = None # option to specifiy full uri
|
||||
cors_origins: Optional[list] = ["http://memgpt.localhost", "http://localhost:8283", "http://localhost:8083"]
|
||||
|
||||
# agent configuration defaults
|
||||
default_preset: Optional[str] = "memgpt_chat"
|
||||
|
||||
@property
|
||||
def memgpt_pg_uri(self) -> str:
|
||||
if self.pg_uri:
|
||||
|
||||
11
poetry.lock
generated
11
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohttp"
|
||||
@@ -2999,6 +2999,7 @@ description = "Nvidia JIT LTO Library"
|
||||
optional = true
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_aarch64.whl", hash = "sha256:004186d5ea6a57758fd6d57052a123c73a4815adf365eb8dd6a85c9eaa7535ff"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d9714f27c1d0f0895cd8915c07a87a1d0029a0aa36acaf9156952ec2a8a12189"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.5.40-py3-none-win_amd64.whl", hash = "sha256:c3401dc8543b52d3a8158007a0c1ab4e9c768fcbd24153a48c86972102197ddd"},
|
||||
]
|
||||
@@ -3874,13 +3875,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.7.3"
|
||||
version = "2.7.4"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pydantic-2.7.3-py3-none-any.whl", hash = "sha256:ea91b002777bf643bb20dd717c028ec43216b24a6001a280f83877fd2655d0b4"},
|
||||
{file = "pydantic-2.7.3.tar.gz", hash = "sha256:c46c76a40bb1296728d7a8b99aa73dd70a48c3510111ff290034f860c99c419e"},
|
||||
{file = "pydantic-2.7.4-py3-none-any.whl", hash = "sha256:ee8538d41ccb9c0a9ad3e0e5f07bf15ed8015b481ced539a1759d8cc89ae90d0"},
|
||||
{file = "pydantic-2.7.4.tar.gz", hash = "sha256:0c84efd9548d545f63ac0060c1e4d39bb9b14db8b3c0652338aecc07b5adec52"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6340,4 +6341,4 @@ server = ["fastapi", "uvicorn", "websockets"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.13,>=3.10"
|
||||
content-hash = "bfa14c084ae06f7d5ceb561406794d93f90808c20b098af13110f4ebe38c7928"
|
||||
content-hash = "904e243813980d61b67db2d1dc96cc782299bcdfb80bb122e4f2013f11f2a9c4"
|
||||
|
||||
@@ -39,7 +39,7 @@ chromadb = "^0.5.0"
|
||||
sqlalchemy-json = "^0.7.0"
|
||||
fastapi = {version = "^0.104.1", optional = true}
|
||||
uvicorn = {version = "^0.24.0.post1", optional = true}
|
||||
pydantic = "^2.5.2"
|
||||
pydantic = "^2.7.4"
|
||||
pyautogen = {version = "0.2.22", optional = true}
|
||||
html2text = "^2020.1.16"
|
||||
docx2txt = "^0.8"
|
||||
|
||||
@@ -40,9 +40,7 @@ def agent():
|
||||
if not client.server.get_user(user_id=user_id):
|
||||
client.server.create_user({"id": user_id})
|
||||
|
||||
agent_state = client.create_agent(
|
||||
preset=constants.DEFAULT_PRESET,
|
||||
)
|
||||
agent_state = client.create_agent()
|
||||
|
||||
return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import uuid
|
||||
import pytest
|
||||
|
||||
import memgpt.functions.function_sets.base as base_functions
|
||||
from memgpt import constants, create_client
|
||||
from memgpt import create_client
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
|
||||
from .utils import create_config, wipe_config
|
||||
@@ -25,9 +25,7 @@ def agent_obj():
|
||||
|
||||
client = create_client()
|
||||
|
||||
agent_state = client.create_agent(
|
||||
preset=constants.DEFAULT_PRESET,
|
||||
)
|
||||
agent_state = client.create_agent()
|
||||
|
||||
global agent_obj
|
||||
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
|
||||
|
||||
@@ -110,9 +110,7 @@ def test_concurrent_messages(admin_client):
|
||||
response = admin_client.create_user()
|
||||
token = response.api_key
|
||||
client = create_client(base_url=admin_client.base_url, token=token)
|
||||
agent = client.create_agent(
|
||||
name=test_agent_name,
|
||||
)
|
||||
agent = client.create_agent()
|
||||
|
||||
print("Agent created", agent.id)
|
||||
|
||||
|
||||
@@ -3,11 +3,11 @@ import os
|
||||
import uuid
|
||||
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.data_types import Message
|
||||
from memgpt.data_types import AgentState, Message
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.llm_api.llm_api_tools import create
|
||||
from memgpt.models.pydantic_models import EmbeddingConfigModel, LLMConfigModel
|
||||
from memgpt.presets.presets import load_preset
|
||||
from memgpt.presets.presets import load_module_tools
|
||||
from memgpt.prompts import gpt_system
|
||||
|
||||
messages = [Message(role="system", text=gpt_system.get_system_text("memgpt_chat")), Message(role="user", text="How are you?")]
|
||||
@@ -26,13 +26,22 @@ def run_llm_endpoint(filename):
|
||||
print(config_data)
|
||||
llm_config = LLMConfigModel(**config_data)
|
||||
embedding_config = EmbeddingConfigModel(**json.load(open(embedding_config_path)))
|
||||
agent_state = AgentState(
|
||||
name="test_agent",
|
||||
tools=[tool.name for tool in load_module_tools()],
|
||||
system="",
|
||||
persona="",
|
||||
human="",
|
||||
preset="memgpt_chat",
|
||||
embedding_config=embedding_config,
|
||||
llm_config=llm_config,
|
||||
user_id=uuid.UUID(int=1),
|
||||
state={"persona": "", "human": "", "messages": None},
|
||||
)
|
||||
agent = Agent(
|
||||
interface=None,
|
||||
preset=load_preset("memgpt_chat", user_id=uuid.UUID(int=1)),
|
||||
name="test_agent",
|
||||
created_by=uuid.UUID(int=1),
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
tools=load_module_tools(),
|
||||
agent_state=agent_state,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True,
|
||||
)
|
||||
|
||||
@@ -8,13 +8,12 @@ from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.cli.cli_load import load_directory
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.data_types import AgentState, EmbeddingConfig, User
|
||||
from memgpt.data_types import EmbeddingConfig, User
|
||||
from memgpt.metadata import MetadataStore
|
||||
|
||||
# from memgpt.data_sources.connectors import DirectoryConnector, load_data
|
||||
# import memgpt
|
||||
from memgpt.settings import settings
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
|
||||
from .utils import create_config, wipe_config, with_qdrant_storage
|
||||
@@ -131,15 +130,17 @@ def test_load_directory(
|
||||
# config.save()
|
||||
|
||||
# create user and agent
|
||||
agent = AgentState(
|
||||
user_id=user.id,
|
||||
name="test_agent",
|
||||
preset=TEST_MEMGPT_CONFIG.preset,
|
||||
persona=get_persona_text(TEST_MEMGPT_CONFIG.persona),
|
||||
human=get_human_text(TEST_MEMGPT_CONFIG.human),
|
||||
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
)
|
||||
# agent = AgentState(
|
||||
# user_id=user.id,
|
||||
# name="test_agent",
|
||||
# preset=TEST_MEMGPT_CONFIG.preset,
|
||||
# persona=get_persona_text(TEST_MEMGPT_CONFIG.persona),
|
||||
# human=get_human_text(TEST_MEMGPT_CONFIG.human),
|
||||
# llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
# embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
# tools=[],
|
||||
# system="",
|
||||
# )
|
||||
ms.delete_user(user.id)
|
||||
ms.create_user(user)
|
||||
# ms.create_agent(agent)
|
||||
@@ -223,7 +224,6 @@ def test_load_directory(
|
||||
|
||||
# cleanup
|
||||
ms.delete_user(user.id)
|
||||
ms.delete_agent(agent.id)
|
||||
ms.delete_source(sources[0].id)
|
||||
|
||||
# revert to openai config
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
from memgpt.migrate import migrate_all_agents
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
from .utils import create_config, wipe_config
|
||||
|
||||
|
||||
def test_migrate_0211():
|
||||
wipe_config()
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
create_config("openai")
|
||||
else:
|
||||
create_config("memgpt_hosted")
|
||||
|
||||
data_dir = "tests/data/memgpt-0.2.11"
|
||||
tmp_dir = f"tmp_{str(uuid.uuid4())}"
|
||||
shutil.copytree(data_dir, tmp_dir)
|
||||
print("temporary directory:", tmp_dir)
|
||||
# os.environ["MEMGPT_CONFIG_PATH"] = os.path.join(data_dir, "config")
|
||||
# print(f"MEMGPT_CONFIG_PATH={os.environ['MEMGPT_CONFIG_PATH']}")
|
||||
try:
|
||||
agent_res = migrate_all_agents(tmp_dir, debug=True)
|
||||
assert len(agent_res["failed_migrations"]) == 0, f"Failed migrations: {agent_res}"
|
||||
|
||||
# NOTE: source tests had to be removed since it is no longer possible to migrate llama index vector indices
|
||||
# source_res = migrate_all_sources(tmp_dir)
|
||||
# assert len(source_res["failed_migrations"]) == 0, f"Failed migrations: {source_res}"
|
||||
|
||||
# TODO: assert everything is in the DB
|
||||
|
||||
server = SyncServer()
|
||||
for agent_name in agent_res["migration_candidates"]:
|
||||
if agent_name not in agent_res["failed_migrations"]:
|
||||
# assert agent data exists
|
||||
agent_state = server.ms.get_agent(agent_name=agent_name, user_id=agent_res["user_id"])
|
||||
assert agent_state is not None, f"Missing agent {agent_name}"
|
||||
|
||||
# assert in context messages exist
|
||||
message_ids = server.get_in_context_message_ids(user_id=agent_res["user_id"], agent_id=agent_state.id)
|
||||
assert len(message_ids) > 0
|
||||
|
||||
# assert recall memories exist
|
||||
messages = server.get_agent_messages(
|
||||
user_id=agent_state.user_id,
|
||||
agent_id=agent_state.id,
|
||||
start=0,
|
||||
count=1000,
|
||||
)
|
||||
assert len(messages) > 0
|
||||
|
||||
# for source_name in source_res["migration_candidates"]:
|
||||
# if source_name not in source_res["failed_migrations"]:
|
||||
# # assert source data exists
|
||||
# source = server.ms.get_source(source_name=source_name, user_id=source_res["user_id"])
|
||||
# assert source is not None
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir)
|
||||
@@ -6,6 +6,26 @@ import unittest.mock
|
||||
import pytest
|
||||
|
||||
from memgpt.cli.cli_config import add, delete, list
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from tests.utils import create_config
|
||||
|
||||
|
||||
def _reset_config():
|
||||
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
create_config("openai")
|
||||
credentials = MemGPTCredentials(
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
else: # hosted
|
||||
create_config("memgpt_hosted")
|
||||
credentials = MemGPTCredentials()
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
config.save()
|
||||
credentials.save()
|
||||
print("_reset_config :: ", config.config_path)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This is a helper function.")
|
||||
@@ -31,6 +51,7 @@ def reset_env_variables(server_url, token):
|
||||
|
||||
|
||||
def test_crud_human(capsys):
|
||||
_reset_config()
|
||||
|
||||
server_url, token = unset_env_variables()
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import memgpt.utils as utils
|
||||
utils.DEBUG = True
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.presets.presets import load_module_tools
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.settings import settings
|
||||
|
||||
@@ -73,6 +74,7 @@ def agent_id(server, user_id):
|
||||
user_id=user_id,
|
||||
name="test_agent",
|
||||
preset="memgpt_chat",
|
||||
tools=[tool.name for tool in load_module_tools()],
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
yield agent_state.id
|
||||
|
||||
@@ -192,12 +192,12 @@ def test_storage(
|
||||
human=get_human_text(TEST_MEMGPT_CONFIG.human),
|
||||
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
system="",
|
||||
tools=[],
|
||||
state={
|
||||
"persona": "",
|
||||
"human": "",
|
||||
"system": "",
|
||||
"functions": [],
|
||||
"messages": [],
|
||||
"messages": None,
|
||||
},
|
||||
)
|
||||
ms.create_user(user)
|
||||
|
||||
Reference in New Issue
Block a user