diff --git a/memgpt/schemas/memory.py b/memgpt/schemas/memory.py index ea15356a..57bb4fdb 100644 --- a/memgpt/schemas/memory.py +++ b/memgpt/schemas/memory.py @@ -26,9 +26,9 @@ class Memory(BaseModel, validate_assignment=True): # Memory.template is a Jinja2 template for compiling memory module into a prompt string. prompt_template: str = Field( default="{% for block in memory.values() %}" - '<{{ block.name }} characters="{{ block.value|length }}/{{ block.limit }}">\n' + '<{{ block.label }} characters="{{ block.value|length }}/{{ block.limit }}">\n' "{{ block.value }}\n" - "" + "" "{% if not loop.last %}\n{% endif %}" "{% endfor %}", description="Jinja2 template for compiling memory blocks into a prompt string", @@ -99,6 +99,10 @@ class Memory(BaseModel, validate_assignment=True): else: return self.memory[name] + def get_blocks(self) -> List[Block]: + """Return a list of the blocks held inside the memory object""" + return list(self.memory.values()) + def link_block(self, name: str, block: Block, override: Optional[bool] = False): """Link a new block to the memory object""" if not isinstance(block, Block): @@ -143,8 +147,10 @@ class BasicBlockMemory(Memory): super().__init__() for block in blocks: # TODO: centralize these internal schema validations - assert block.name is not None and block.name != "", "each existing chat block must have a name" - self.link_block(name=block.name, block=block) + # assert block.name is not None and block.name != "", "each existing chat block must have a name" + # self.link_block(name=block.name, block=block) + assert block.label is not None and block.label != "", "each existing chat block must have a name" + self.link_block(name=block.label, block=block) def core_memory_append(self: "Agent", name: str, content: str) -> Optional[str]: # type: ignore """ diff --git a/memgpt/server/rest_api/routers/v1/agents.py b/memgpt/server/rest_api/routers/v1/agents.py index 809b7f62..51c255e7 100644 --- a/memgpt/server/rest_api/routers/v1/agents.py +++ b/memgpt/server/rest_api/routers/v1/agents.py @@ -17,6 +17,7 @@ from memgpt.schemas.memgpt_request import MemGPTRequest from memgpt.schemas.memgpt_response import MemGPTResponse from memgpt.schemas.memory import ( ArchivalMemorySummary, + BasicBlockMemory, CreateArchivalMemory, Memory, RecallMemorySummary, @@ -58,6 +59,11 @@ def create_agent( """ actor = server.get_current_user() agent.user_id = actor.id + # TODO: sarah make general + # TODO: eventually remove this + assert agent.memory is not None # TODO: dont force this, can be None (use default human/person) + blocks = agent.memory.get_blocks() + agent.memory = BasicBlockMemory(blocks=blocks) return server.create_agent(agent, user_id=actor.id) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index b9658fbf..76dac18e 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -30,12 +30,18 @@ from memgpt.data_sources.connectors import DataConnector, load_data # Token, # User, # ) -from memgpt.functions.functions import generate_schema, load_function_set +from memgpt.functions.functions import ( + generate_schema, + load_function_set, + parse_source_code, +) +from memgpt.functions.schema_generator import generate_schema # TODO use custom interface from memgpt.interface import AgentInterface # abstract from memgpt.interface import CLIInterface # for printing to terminal from memgpt.log import get_logger +from memgpt.memory import get_memory_functions from memgpt.metadata import MetadataStore from memgpt.prompts import gpt_system from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState @@ -753,16 +759,44 @@ class SyncServer(Server): # get tools + make sure they exist tool_objs = [] - for tool_name in request.tools: - tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id) - assert tool_obj, f"Tool {tool_name} does not exist" - tool_objs.append(tool_obj) + if request.tools: + for tool_name in request.tools: + tool_obj = self.ms.get_tool(tool_name=tool_name, user_id=user_id) + assert tool_obj, f"Tool {tool_name} does not exist" + tool_objs.append(tool_obj) + + assert request.memory is not None + memory_functions = get_memory_functions(request.memory) + for func_name, func in memory_functions.items(): + + if request.tools and func_name in request.tools: + # tool already added + continue + source_code = parse_source_code(func) + json_schema = generate_schema(func, func_name) + source_type = "python" + tags = ["memory", "memgpt-base"] + tool = self.create_tool( + request=ToolCreate( + source_code=source_code, + source_type=source_type, + tags=tags, + json_schema=json_schema, + user_id=user_id, + ), + update=True, + user_id=user_id, + ) + tool_objs.append(tool) + if not request.tools: + request.tools = [] + request.tools.append(tool.name) # TODO: save the agent state agent_state = AgentState( name=request.name, user_id=user_id, - tools=request.tools, # name=id for tools + tools=request.tools if request.tools else [], llm_config=llm_config, embedding_config=embedding_config, system=request.system,