fix: various fixes to get create agent REST API to work (#1763)
This commit is contained in:
@@ -26,9 +26,9 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
# Memory.template is a Jinja2 template for compiling memory module into a prompt string.
|
||||
prompt_template: str = Field(
|
||||
default="{% for block in memory.values() %}"
|
||||
'<{{ block.name }} characters="{{ block.value|length }}/{{ block.limit }}">\n'
|
||||
'<{{ block.label }} characters="{{ block.value|length }}/{{ block.limit }}">\n'
|
||||
"{{ block.value }}\n"
|
||||
"</{{ block.name }}>"
|
||||
"</{{ block.label }}>"
|
||||
"{% if not loop.last %}\n{% endif %}"
|
||||
"{% endfor %}",
|
||||
description="Jinja2 template for compiling memory blocks into a prompt string",
|
||||
@@ -99,6 +99,10 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
else:
|
||||
return self.memory[name]
|
||||
|
||||
def get_blocks(self) -> List[Block]:
|
||||
"""Return a list of the blocks held inside the memory object"""
|
||||
return list(self.memory.values())
|
||||
|
||||
def link_block(self, name: str, block: Block, override: Optional[bool] = False):
|
||||
"""Link a new block to the memory object"""
|
||||
if not isinstance(block, Block):
|
||||
@@ -143,8 +147,10 @@ class BasicBlockMemory(Memory):
|
||||
super().__init__()
|
||||
for block in blocks:
|
||||
# TODO: centralize these internal schema validations
|
||||
assert block.name is not None and block.name != "", "each existing chat block must have a name"
|
||||
self.link_block(name=block.name, block=block)
|
||||
# assert block.name is not None and block.name != "", "each existing chat block must have a name"
|
||||
# self.link_block(name=block.name, block=block)
|
||||
assert block.label is not None and block.label != "", "each existing chat block must have a name"
|
||||
self.link_block(name=block.label, block=block)
|
||||
|
||||
def core_memory_append(self: "Agent", name: str, content: str) -> Optional[str]: # type: ignore
|
||||
"""
|
||||
|
||||
@@ -17,6 +17,7 @@ from memgpt.schemas.memgpt_request import MemGPTRequest
|
||||
from memgpt.schemas.memgpt_response import MemGPTResponse
|
||||
from memgpt.schemas.memory import (
|
||||
ArchivalMemorySummary,
|
||||
BasicBlockMemory,
|
||||
CreateArchivalMemory,
|
||||
Memory,
|
||||
RecallMemorySummary,
|
||||
@@ -58,6 +59,11 @@ def create_agent(
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
agent.user_id = actor.id
|
||||
# TODO: sarah make general
|
||||
# TODO: eventually remove this
|
||||
assert agent.memory is not None # TODO: dont force this, can be None (use default human/person)
|
||||
blocks = agent.memory.get_blocks()
|
||||
agent.memory = BasicBlockMemory(blocks=blocks)
|
||||
|
||||
return server.create_agent(agent, user_id=actor.id)
|
||||
|
||||
|
||||
@@ -30,12 +30,18 @@ from memgpt.data_sources.connectors import DataConnector, load_data
|
||||
# Token,
|
||||
# User,
|
||||
# )
|
||||
from memgpt.functions.functions import generate_schema, load_function_set
|
||||
from memgpt.functions.functions import (
|
||||
generate_schema,
|
||||
load_function_set,
|
||||
parse_source_code,
|
||||
)
|
||||
from memgpt.functions.schema_generator import generate_schema
|
||||
|
||||
# TODO use custom interface
|
||||
from memgpt.interface import AgentInterface # abstract
|
||||
from memgpt.interface import CLIInterface # for printing to terminal
|
||||
from memgpt.log import get_logger
|
||||
from memgpt.memory import get_memory_functions
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
||||
@@ -753,16 +759,44 @@ class SyncServer(Server):
|
||||
|
||||
# get tools + make sure they exist
|
||||
tool_objs = []
|
||||
for tool_name in request.tools:
|
||||
tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
|
||||
assert tool_obj, f"Tool {tool_name} does not exist"
|
||||
tool_objs.append(tool_obj)
|
||||
if request.tools:
|
||||
for tool_name in request.tools:
|
||||
tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id)
|
||||
assert tool_obj, f"Tool {tool_name} does not exist"
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
assert request.memory is not None
|
||||
memory_functions = get_memory_functions(request.memory)
|
||||
for func_name, func in memory_functions.items():
|
||||
|
||||
if request.tools and func_name in request.tools:
|
||||
# tool already added
|
||||
continue
|
||||
source_code = parse_source_code(func)
|
||||
json_schema = generate_schema(func, func_name)
|
||||
source_type = "python"
|
||||
tags = ["memory", "memgpt-base"]
|
||||
tool = self.create_tool(
|
||||
request=ToolCreate(
|
||||
source_code=source_code,
|
||||
source_type=source_type,
|
||||
tags=tags,
|
||||
json_schema=json_schema,
|
||||
user_id=user_id,
|
||||
),
|
||||
update=True,
|
||||
user_id=user_id,
|
||||
)
|
||||
tool_objs.append(tool)
|
||||
if not request.tools:
|
||||
request.tools = []
|
||||
request.tools.append(tool.name)
|
||||
|
||||
# TODO: save the agent state
|
||||
agent_state = AgentState(
|
||||
name=request.name,
|
||||
user_id=user_id,
|
||||
tools=request.tools, # name=id for tools
|
||||
tools=request.tools if request.tools else [],
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
system=request.system,
|
||||
|
||||
Reference in New Issue
Block a user