From 282e7b52898e9f00bc0b9f062887e4ddf47edc16 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Tue, 3 Sep 2024 15:44:52 -0700 Subject: [PATCH] refactor: use `jinja2` templates for `Memory.compile` instead of writing Python code (#1687) --- memgpt/agent.py | 6 +- memgpt/main.py | 7 +- memgpt/memory.py | 206 +-------------------------------------- memgpt/schemas/block.py | 4 +- memgpt/schemas/memory.py | 80 ++++++++++++--- poetry.lock | 12 +-- pyproject.toml | 1 + tests/test_memory.py | 69 ++++++++++++- tests/test_new_client.py | 6 +- 9 files changed, 152 insertions(+), 239 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index f46031f7..c7995e12 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -698,7 +698,7 @@ class Agent(object): # TODO: ensure we're passing in metadata store from all surfaces if ms is not None: should_update = False - for block in self.agent_state.memory.to_dict().values(): + for block in self.agent_state.memory.to_dict()["memory"].values(): if not block.get("template", False): should_update = True if should_update: @@ -1020,7 +1020,7 @@ class Agent(object): return if ms: - for block in self.memory.to_dict().values(): + for block in self.memory.to_dict()["memory"].values(): if block.get("templates", False): # we don't expect to update shared memory blocks that # are templates. this is something we could update in the @@ -1225,7 +1225,7 @@ def save_agent_memory(agent: Agent, ms: MetadataStore): NOTE: we are assuming agent.update_state has already been called. """ - for block_dict in agent.memory.to_dict().values(): + for block_dict in agent.memory.to_dict()["memory"].values(): # TODO: block creation should happen in one place to enforce these sort of constraints consistently. if block_dict.get("user_id", None) is None: block_dict["user_id"] = agent.agent_state.user_id diff --git a/memgpt/main.py b/memgpt/main.py index 2bbc5411..d9639ba5 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -1,4 +1,3 @@ -import json import os import sys import traceback @@ -190,9 +189,9 @@ def run_agent_loop( elif user_input.lower() == "/memory": print(f"\nDumping memory contents:\n") - print(f"{str(memgpt_agent.memory)}") - print(f"{str(memgpt_agent.persistence_manager.archival_memory)}") - print(f"{str(memgpt_agent.persistence_manager.recall_memory)}") + print(f"{memgpt_agent.memory.compile()}") + print(f"{memgpt_agent.persistence_manager.archival_memory.compile()}") + print(f"{memgpt_agent.persistence_manager.recall_memory.compile()}") continue elif user_input.lower() == "/model": diff --git a/memgpt/memory.py b/memgpt/memory.py index 7543c5c7..f03bafb7 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -19,114 +19,6 @@ from memgpt.utils import ( validate_date_format, ) -# class MemoryModule(BaseModel): -# """Base class for memory modules""" -# -# description: Optional[str] = None -# limit: int = 2000 -# value: Optional[Union[List[str], str]] = None -# -# def __setattr__(self, name, value): -# """Run validation if self.value is updated""" -# super().__setattr__(name, value) -# if name == "value": -# # run validation -# self.__class__.validate(self.dict(exclude_unset=True)) -# -# @validator("value", always=True) -# def check_value_length(cls, v, values): -# if v is not None: -# # Fetching the limit from the values dictionary -# limit = values.get("limit", 2000) # Default to 2000 if limit is not yet set -# -# # Check if the value exceeds the limit -# if isinstance(v, str): -# length = len(v) -# elif isinstance(v, list): -# length = sum(len(item) for item in v) -# else: -# raise ValueError("Value must be either a string or a list of strings.") -# -# if length > limit: -# error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})." -# # TODO: add archival memory error? -# raise ValueError(error_msg) -# return v -# -# def __len__(self): -# return len(str(self)) -# -# def __str__(self) -> str: -# if isinstance(self.value, list): -# return ",".join(self.value) -# elif isinstance(self.value, str): -# return self.value -# else: -# return "" -# -# -# class BaseMemory: -# -# def __init__(self): -# self.memory = {} -# -# @classmethod -# def load(cls, state: dict): -# """Load memory from dictionary object""" -# obj = cls() -# for key, value in state.items(): -# obj.memory[key] = MemoryModule(**value) -# return obj -# -# def __str__(self) -> str: -# """Representation of the memory in-context""" -# section_strs = [] -# for section, module in self.memory.items(): -# section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n') -# return "\n".join(section_strs) -# -# def to_dict(self): -# """Convert to dictionary representation""" -# return {key: value.dict() for key, value in self.memory.items()} -# -# -# class ChatMemory(BaseMemory): -# -# def __init__(self, persona: str, human: str, limit: int = 2000): -# self.memory = { -# "persona": MemoryModule(name="persona", value=persona, limit=limit), -# "human": MemoryModule(name="human", value=human, limit=limit), -# } -# -# def core_memory_append(self, name: str, content: str) -> Optional[str]: -# """ -# Append to the contents of core memory. -# -# Args: -# name (str): Section of the memory to be edited (persona or human). -# content (str): Content to write to the memory. All unicode (including emojis) are supported. -# -# Returns: -# Optional[str]: None is always returned as this function does not produce a response. -# """ -# self.memory[name].value += "\n" + content -# return None -# -# def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]: -# """ -# Replace the contents of core memory. To delete memories, use an empty string for new_content. -# -# Args: -# name (str): Section of the memory to be edited (persona or human). -# old_content (str): String to replace. Must be an exact match. -# new_content (str): Content to write to the memory. All unicode (including emojis) are supported. -# -# Returns: -# Optional[str]: None is always returned as this function does not produce a response. -# """ -# self.memory[name].value = self.memory[name].value.replace(old_content, new_content) -# return None - def get_memory_functions(cls: Memory) -> List[callable]: """Get memory functions for a memory class""" @@ -151,94 +43,6 @@ def get_memory_functions(cls: Memory) -> List[callable]: return functions -# class CoreMemory(object): -# """Held in-context inside the system message -# -# Core Memory: Refers to the system block, which provides essential, foundational context to the AI. -# This includes the persona information, essential user details, -# and any other baseline data you deem necessary for the AI's basic functioning. -# """ -# -# def __init__(self, persona=None, human=None, persona_char_limit=None, human_char_limit=None, archival_memory_exists=True): -# self.persona = persona -# self.human = human -# self.persona_char_limit = persona_char_limit -# self.human_char_limit = human_char_limit -# -# # affects the error message the AI will see on overflow inserts -# self.archival_memory_exists = archival_memory_exists -# -# def __repr__(self) -> str: -# return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}" -# -# def to_dict(self): -# return { -# "persona": self.persona, -# "human": self.human, -# } -# -# @classmethod -# def load(cls, state): -# return cls(state["persona"], state["human"]) -# -# def edit_persona(self, new_persona): -# if self.persona_char_limit and len(new_persona) > self.persona_char_limit: -# error_msg = f"Edit failed: Exceeds {self.persona_char_limit} character limit (requested {len(new_persona)})." -# if self.archival_memory_exists: -# error_msg = f"{error_msg} Consider summarizing existing core memories in 'persona' and/or moving lower priority content to archival memory to free up space in core memory, then trying again." -# raise ValueError(error_msg) -# -# self.persona = new_persona -# return len(self.persona) -# -# def edit_human(self, new_human): -# if self.human_char_limit and len(new_human) > self.human_char_limit: -# error_msg = f"Edit failed: Exceeds {self.human_char_limit} character limit (requested {len(new_human)})." -# if self.archival_memory_exists: -# error_msg = f"{error_msg} Consider summarizing existing core memories in 'human' and/or moving lower priority content to archival memory to free up space in core memory, then trying again." -# raise ValueError(error_msg) -# -# self.human = new_human -# return len(self.human) -# -# def edit(self, field, content): -# if field == "persona": -# return self.edit_persona(content) -# elif field == "human": -# return self.edit_human(content) -# else: -# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")') -# -# def edit_append(self, field, content, sep="\n"): -# if field == "persona": -# new_content = self.persona + sep + content -# return self.edit_persona(new_content) -# elif field == "human": -# new_content = self.human + sep + content -# return self.edit_human(new_content) -# else: -# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")') -# -# def edit_replace(self, field, old_content, new_content): -# if len(old_content) == 0: -# raise ValueError("old_content cannot be an empty string (must specify old_content to replace)") -# -# if field == "persona": -# if old_content in self.persona: -# new_persona = self.persona.replace(old_content, new_content) -# return self.edit_persona(new_persona) -# else: -# raise ValueError("Content not found in persona (make sure to use exact string)") -# elif field == "human": -# if old_content in self.human: -# new_human = self.human.replace(old_content, new_content) -# return self.edit_human(new_human) -# else: -# raise ValueError("Content not found in human (make sure to use exact string)") -# else: -# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")') - - def _format_summary_history(message_history: List[Message]): # TODO use existing prompt formatters for this (eg ChatML) return "\n".join([f"{m.role}: {m.text}" for m in message_history]) @@ -308,7 +112,7 @@ class ArchivalMemory(ABC): """ @abstractmethod - def __repr__(self) -> str: + def compile(self) -> str: pass @@ -322,7 +126,7 @@ class RecallMemory(ABC): """Search messages between start_date and end_date in recall memory""" @abstractmethod - def __repr__(self) -> str: + def compile(self) -> str: pass @abstractmethod @@ -350,7 +154,7 @@ class DummyRecallMemory(RecallMemory): def __len__(self): return len(self._message_logs) - def __repr__(self) -> str: + def compile(self) -> str: # don't dump all the conversations, just statistics system_count = user_count = assistant_count = function_count = other_count = 0 for msg in self._message_logs: @@ -467,7 +271,7 @@ class BaseRecallMemory(RecallMemory): results_json = [message.to_openai_dict_search_results() for message in results] return results_json, len(results) - def __repr__(self) -> str: + def compile(self) -> str: total = self.storage.size() system_count = self.storage.size(filters={"role": "system"}) user_count = self.storage.size(filters={"role": "user"}) @@ -597,7 +401,7 @@ class EmbeddingArchivalMemory(ArchivalMemory): print("Archival search error", e) raise e - def __repr__(self) -> str: + def compile(self) -> str: limit = 10 passages = [] for passage in list(self.storage.get_all(limit=limit)): # TODO: only get first 10 diff --git a/memgpt/schemas/block.py b/memgpt/schemas/block.py index 79b8ef85..dd43c305 100644 --- a/memgpt/schemas/block.py +++ b/memgpt/schemas/block.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional from pydantic import Field, model_validator from typing_extensions import Self @@ -14,7 +14,7 @@ class BaseBlock(MemGPTBase, validate_assignment=True): __id_prefix__ = "block" # data value - value: Optional[Union[List[str], str]] = Field(None, description="Value of the block.") + value: Optional[str] = Field(None, description="Value of the block.") limit: int = Field(2000, description="Character limit of the block.") name: Optional[str] = Field(None, description="Name of the block.") diff --git a/memgpt/schemas/memory.py b/memgpt/schemas/memory.py index 949f832f..1685af9f 100644 --- a/memgpt/schemas/memory.py +++ b/memgpt/schemas/memory.py @@ -1,34 +1,81 @@ -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional +from jinja2 import Template, TemplateSyntaxError from pydantic import BaseModel, Field +# Forward referencing to avoid circular import with Agent -> Memory -> Agent +if TYPE_CHECKING: + from memgpt.agent import Agent + from memgpt.schemas.block import Block class Memory(BaseModel, validate_assignment=True): """Represents the in-context memory of the agent""" - # Private variable to avoid assignments with incorrect types + # Memory.memory is a dict mapping from memory block section to memory block. memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.") + # Memory.template is a Jinja2 template for compiling memory module into a prompt string. + prompt_template: str = Field( + default="{% for section, block in memory.items() %}" + '<{{ section }} characters="{{ block.value|length }}/{{ block.limit }}">\n' + "{{ block.value }}\n" + "" + "{% if not loop.last %}\n{% endif %}" + "{% endfor %}", + description="Jinja2 template for compiling memory blocks into a prompt string", + ) + + def get_prompt_template(self) -> str: + """Return the current Jinja2 template string.""" + return str(self.prompt_template) + + def set_prompt_template(self, prompt_template: str): + """ + Set a new Jinja2 template string. + Validates the template syntax and compatibility with current memory structure. + """ + try: + # Validate Jinja2 syntax + Template(prompt_template) + + # Validate compatibility with current memory structure + test_render = Template(prompt_template).render(memory=self.memory) + + # If we get here, the template is valid and compatible + self.prompt_template = prompt_template + except TemplateSyntaxError as e: + raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}") + except Exception as e: + raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") + @classmethod def load(cls, state: dict): """Load memory from dictionary object""" obj = cls() - for key, value in state.items(): - obj.memory[key] = Block(**value) + if len(state.keys()) == 2 and "memory" in state and "prompt_template" in state: + # New format + obj.prompt_template = state["prompt_template"] + for key, value in state["memory"].items(): + obj.memory[key] = Block(**value) + else: + # Old format (pre-template) + for key, value in state.items(): + obj.memory[key] = Block(**value) return obj def compile(self) -> str: - """Generate a string representation of the memory in-context""" - section_strs = [] - for section, module in self.memory.items(): - section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n') - return "\n".join(section_strs) + """Generate a string representation of the memory in-context using the Jinja2 template""" + template = Template(self.prompt_template) + return template.render(memory=self.memory) def to_dict(self): """Convert to dictionary representation""" - return {key: value.dict() for key, value in self.memory.items()} + return { + "memory": {key: value.model_dump() for key, value in self.memory.items()}, + "prompt_template": self.prompt_template, + } def to_flat_dict(self): """Convert to a dictionary that maps directly from block names to values""" @@ -41,7 +88,7 @@ class Memory(BaseModel, validate_assignment=True): def get_block(self, name: str) -> Block: """Correct way to index into the memory.memory field, returns a Block""" if name not in self.memory: - return KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})") + raise KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})") else: return self.memory[name] @@ -56,19 +103,20 @@ class Memory(BaseModel, validate_assignment=True): self.memory[name] = block - def update_block_value(self, name: str, value: Union[List[str], str]): + def update_block_value(self, name: str, value: str): """Update the value of a block""" if name not in self.memory: raise ValueError(f"Block with name {name} does not exist") - if not (isinstance(value, str) or (isinstance(value, list) and all(isinstance(v, str) for v in value))): - raise ValueError(f"Provided value must be a string or list of strings") + if not isinstance(value, str): + raise ValueError(f"Provided value must be a string") self.memory[name].value = value # TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names. class BaseChatMemory(Memory): - def core_memory_append(self, name: str, content: str) -> Optional[str]: + + def core_memory_append(self: "Agent", name: str, content: str) -> Optional[str]: # type: ignore """ Append to the contents of core memory. @@ -84,7 +132,7 @@ class BaseChatMemory(Memory): self.memory.update_block_value(name=name, value=new_value) return None - def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]: + def core_memory_replace(self: "Agent", name: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore """ Replace the contents of core memory. To delete memories, use an empty string for new_content. diff --git a/poetry.lock b/poetry.lock index 2bfb9873..2e3f4d20 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2557,7 +2557,7 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] name = "jinja2" version = "3.1.4" description = "A very fast and expressive template engine." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, @@ -3367,7 +3367,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] name = "markupsafe" version = "2.1.5" description = "Safely add untrusted strings to HTML/XML markup." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, @@ -3502,6 +3502,7 @@ python-versions = ">=3.7" files = [ {file = "milvus_lite-2.4.9-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d3e617b3d68c09ad656d54bc3d8cc4ef6ef56c54015e1563d4fe4bcec6b7c90a"}, {file = "milvus_lite-2.4.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6e7029282d6829b277ebb92f64e2370be72b938e34770e1eb649346bda5d1d7f"}, + {file = "milvus_lite-2.4.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9b8e991e4e433596f6a399a165c1a506f823ec9133332e03d7f8a114bff4550d"}, {file = "milvus_lite-2.4.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:7f53e674602101cfbcf0a4a59d19eaa139dfd5580639f3040ad73d901f24fc0b"}, ] @@ -7129,11 +7130,6 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, - {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, - {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, - {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, - {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, - {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -7965,4 +7961,4 @@ server = ["fastapi", "uvicorn", "websockets"] [metadata] lock-version = "2.0" python-versions = "<3.13,>=3.10" -content-hash = "3195feb8a0715fb8a8a191e6c402ae6c01f991921ac7a5629a333e0556d7d02a" +content-hash = "835cba3934be7d79b7f2eb2c5c4c81af2eebe432ed55b186b1616022605e2f70" diff --git a/pyproject.toml b/pyproject.toml index 2b7fbe8b..659c5303 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ crewai-tools = {version = "^0.8.3", optional = true} docker = {version = "^7.1.0", optional = true} tiktoken = "^0.7.0" nltk = "^3.8.1" +jinja2 = "^3.1.4" [tool.poetry.extras] local = ["llama-index-embeddings-huggingface"] diff --git a/tests/test_memory.py b/tests/test_memory.py index 8f8057f3..39880b2e 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -18,7 +18,7 @@ def test_create_chat_memory(): def test_dump_memory_as_json(sample_memory: Memory): """Test dumping ChatMemory as JSON compatible dictionary""" - memory_dict = sample_memory.to_dict() + memory_dict = sample_memory.to_dict()["memory"] assert isinstance(memory_dict, dict) assert "persona" in memory_dict assert memory_dict["persona"]["value"] == "Chat Agent" @@ -26,7 +26,7 @@ def test_dump_memory_as_json(sample_memory: Memory): def test_load_memory_from_json(sample_memory: Memory): """Test loading ChatMemory from a JSON compatible dictionary""" - memory_dict = sample_memory.to_dict() + memory_dict = sample_memory.to_dict()["memory"] print(memory_dict) new_memory = Memory.load(memory_dict) assert new_memory.get_block("persona").value == "Chat Agent" @@ -63,3 +63,68 @@ def test_memory_limit_validation(sample_memory: Memory): with pytest.raises(ValueError): sample_memory.get_block("persona").value = "x" * 3000 + + +def test_memory_jinja2_template_load(sample_memory: Memory): + """Test loading a memory with and without a jinja2 template""" + + # Test loading a memory with a template + memory_dict = sample_memory.to_dict() + memory_dict["prompt_template"] = sample_memory.get_prompt_template() + new_memory = Memory.load(memory_dict) + assert new_memory.get_prompt_template() == sample_memory.get_prompt_template() + + # Test loading a memory without a template (old format) + memory_dict = sample_memory.to_dict() + memory_dict_old_format = memory_dict["memory"] + new_memory = Memory.load(memory_dict_old_format) + assert new_memory.get_prompt_template() is not None # Ensure a default template is set + assert new_memory.to_dict()["memory"] == memory_dict_old_format + + +def test_memory_jinja2_template(sample_memory: Memory): + """Test to make sure the jinja2 template string is equivalent to the old __repr__ method""" + + def old_repr(self: Memory) -> str: + """Generate a string representation of the memory in-context""" + section_strs = [] + for section, module in self.memory.items(): + section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n') + return "\n".join(section_strs) + + old_repr_str = old_repr(sample_memory) + new_repr_str = sample_memory.compile() + assert new_repr_str == old_repr_str, f"Expected '{old_repr_str}' to be '{new_repr_str}'" + + +def test_memory_jinja2_set_template(sample_memory: Memory): + """Test setting the template for the memory""" + + example_template = sample_memory.get_prompt_template() + + # Try setting a valid template + sample_memory.set_prompt_template(prompt_template=example_template) + + # Try setting an invalid template (bad jinja2) + template_bad_jinja = ( + "{% for section, module in mammoth.items() %}" + '<{{ section }} characters="{{ module.value|length }}/{{ module.limit }}">\n' + "{{ module.value }}\n" + "" + "{% if not loop.last %}\n{% endif %}" + "{% endfor %" # Missing closing curly brace + ) + with pytest.raises(ValueError): + sample_memory.set_prompt_template(prompt_template=template_bad_jinja) + + # Try setting an invalid template (not compatible with memory structure) + template_bad_memory_structure = ( + "{% for section, module in mammoth.items() %}" + '<{{ section }} characters="{{ module.value|length }}/{{ module.limit }}">\n' + "{{ module.value }}\n" + "" + "{% if not loop.last %}\n{% endif %}" + "{% endfor %}" + ) + with pytest.raises(ValueError): + sample_memory.set_prompt_template(prompt_template=template_bad_memory_structure) diff --git a/tests/test_new_client.py b/tests/test_new_client.py index 0e2bc4b8..d4442565 100644 --- a/tests/test_new_client.py +++ b/tests/test_new_client.py @@ -42,7 +42,7 @@ def test_agent(client: Union[LocalClient, RESTClient]): print("TOOLS", [t.name for t in tools]) agent_state = client.get_agent(agent_state_test.id) assert agent_state.name == "test_agent2" - for block in agent_state.memory.to_dict().values(): + for block in agent_state.memory.to_dict()["memory"].values(): db_block = client.server.ms.get_block(block.get("id")) assert db_block is not None, "memory block not persisted on agent create" assert db_block.value == block.get("value"), "persisted block data does not match in-memory data" @@ -134,7 +134,7 @@ def test_agent_with_shared_blocks(client): ) assert isinstance(first_agent_state_test.memory, Memory) - first_blocks_dict = first_agent_state_test.memory.to_dict() + first_blocks_dict = first_agent_state_test.memory.to_dict()["memory"] assert persona_block.id == first_blocks_dict.get("persona", {}).get("id") assert human_block.id == first_blocks_dict.get("human", {}).get("id") client.update_in_context_memory(first_agent_state_test.id, section="human", value="I'm an analyst therapist.") @@ -148,7 +148,7 @@ def test_agent_with_shared_blocks(client): ) assert isinstance(second_agent_state_test.memory, Memory) - second_blocks_dict = second_agent_state_test.memory.to_dict() + second_blocks_dict = second_agent_state_test.memory.to_dict()["memory"] assert persona_block.id == second_blocks_dict.get("persona", {}).get("id") assert human_block.id == second_blocks_dict.get("human", {}).get("id") assert second_blocks_dict.get("human", {}).get("value") == "I'm an analyst therapist."