fix: various fixes to get create agent REST API to work (#1763)

This commit is contained in:
Charles Packer
2024-09-22 12:33:28 -07:00
committed by GitHub
parent a9d8445e4a
commit 0b348f8bd9
3 changed files with 56 additions and 10 deletions

View File

@@ -26,9 +26,9 @@ class Memory(BaseModel, validate_assignment=True):
# Memory.template is a Jinja2 template for compiling memory module into a prompt string.
prompt_template: str = Field(
default="{% for block in memory.values() %}"
'<{{ block.name }} characters="{{ block.value|length }}/{{ block.limit }}">\n'
'<{{ block.label }} characters="{{ block.value|length }}/{{ block.limit }}">\n'
"{{ block.value }}\n"
"</{{ block.name }}>"
"</{{ block.label }}>"
"{% if not loop.last %}\n{% endif %}"
"{% endfor %}",
description="Jinja2 template for compiling memory blocks into a prompt string",
@@ -99,6 +99,10 @@ class Memory(BaseModel, validate_assignment=True):
else:
return self.memory[name]
def get_blocks(self) -> List[Block]:
"""Return a list of the blocks held inside the memory object"""
return list(self.memory.values())
def link_block(self, name: str, block: Block, override: Optional[bool] = False):
"""Link a new block to the memory object"""
if not isinstance(block, Block):
@@ -143,8 +147,10 @@ class BasicBlockMemory(Memory):
super().__init__()
for block in blocks:
# TODO: centralize these internal schema validations
assert block.name is not None and block.name != "", "each existing chat block must have a name"
self.link_block(name=block.name, block=block)
# assert block.name is not None and block.name != "", "each existing chat block must have a name"
# self.link_block(name=block.name, block=block)
assert block.label is not None and block.label != "", "each existing chat block must have a name"
self.link_block(name=block.label, block=block)
def core_memory_append(self: "Agent", name: str, content: str) -> Optional[str]: # type: ignore
"""

View File

@@ -17,6 +17,7 @@ from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.memory import (
ArchivalMemorySummary,
BasicBlockMemory,
CreateArchivalMemory,
Memory,
RecallMemorySummary,
@@ -58,6 +59,11 @@ def create_agent(
"""
actor = server.get_current_user()
agent.user_id = actor.id
# TODO: sarah make general
# TODO: eventually remove this
assert agent.memory is not None # TODO: dont force this, can be None (use default human/person)
blocks = agent.memory.get_blocks()
agent.memory = BasicBlockMemory(blocks=blocks)
return server.create_agent(agent, user_id=actor.id)

View File

@@ -30,12 +30,18 @@ from memgpt.data_sources.connectors import DataConnector, load_data
# Token,
# User,
# )
from memgpt.functions.functions import generate_schema, load_function_set
from memgpt.functions.functions import (
generate_schema,
load_function_set,
parse_source_code,
)
from memgpt.functions.schema_generator import generate_schema
# TODO use custom interface
from memgpt.interface import AgentInterface # abstract
from memgpt.interface import CLIInterface # for printing to terminal
from memgpt.log import get_logger
from memgpt.memory import get_memory_functions
from memgpt.metadata import MetadataStore
from memgpt.prompts import gpt_system
from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
@@ -753,16 +759,44 @@ class SyncServer(Server):
# get tools + make sure they exist
tool_objs = []
for tool_name in request.tools:
tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
assert tool_obj, f"Tool {tool_name} does not exist"
tool_objs.append(tool_obj)
if request.tools:
for tool_name in request.tools:
tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
assert tool_obj, f"Tool {tool_name} does not exist"
tool_objs.append(tool_obj)
assert request.memory is not None
memory_functions = get_memory_functions(request.memory)
for func_name, func in memory_functions.items():
if request.tools and func_name in request.tools:
# tool already added
continue
source_code = parse_source_code(func)
json_schema = generate_schema(func, func_name)
source_type = "python"
tags = ["memory", "memgpt-base"]
tool = self.create_tool(
request=ToolCreate(
source_code=source_code,
source_type=source_type,
tags=tags,
json_schema=json_schema,
user_id=user_id,
),
update=True,
user_id=user_id,
)
tool_objs.append(tool)
if not request.tools:
request.tools = []
request.tools.append(tool.name)
# TODO: save the agent state
agent_state = AgentState(
name=request.name,
user_id=user_id,
tools=request.tools, # name=id for tools
tools=request.tools if request.tools else [],
llm_config=llm_config,
embedding_config=embedding_config,
system=request.system,