refactor: use jinja2 templates for Memory.compile instead of writing Python code (#1687)

This commit is contained in:
Charles Packer
2024-09-03 15:44:52 -07:00
committed by GitHub
parent 759b78a553
commit 282e7b5289
9 changed files with 152 additions and 239 deletions

View File

@@ -698,7 +698,7 @@ class Agent(object):
# TODO: ensure we're passing in metadata store from all surfaces
if ms is not None:
should_update = False
for block in self.agent_state.memory.to_dict().values():
for block in self.agent_state.memory.to_dict()["memory"].values():
if not block.get("template", False):
should_update = True
if should_update:
@@ -1020,7 +1020,7 @@ class Agent(object):
return
if ms:
for block in self.memory.to_dict().values():
for block in self.memory.to_dict()["memory"].values():
if block.get("templates", False):
# we don't expect to update shared memory blocks that
# are templates. this is something we could update in the
@@ -1225,7 +1225,7 @@ def save_agent_memory(agent: Agent, ms: MetadataStore):
NOTE: we are assuming agent.update_state has already been called.
"""
for block_dict in agent.memory.to_dict().values():
for block_dict in agent.memory.to_dict()["memory"].values():
# TODO: block creation should happen in one place to enforce these sort of constraints consistently.
if block_dict.get("user_id", None) is None:
block_dict["user_id"] = agent.agent_state.user_id

View File

@@ -1,4 +1,3 @@
import json
import os
import sys
import traceback
@@ -190,9 +189,9 @@ def run_agent_loop(
elif user_input.lower() == "/memory":
print(f"\nDumping memory contents:\n")
print(f"{str(memgpt_agent.memory)}")
print(f"{str(memgpt_agent.persistence_manager.archival_memory)}")
print(f"{str(memgpt_agent.persistence_manager.recall_memory)}")
print(f"{memgpt_agent.memory.compile()}")
print(f"{memgpt_agent.persistence_manager.archival_memory.compile()}")
print(f"{memgpt_agent.persistence_manager.recall_memory.compile()}")
continue
elif user_input.lower() == "/model":

View File

@@ -19,114 +19,6 @@ from memgpt.utils import (
validate_date_format,
)
# class MemoryModule(BaseModel):
# """Base class for memory modules"""
#
# description: Optional[str] = None
# limit: int = 2000
# value: Optional[Union[List[str], str]] = None
#
# def __setattr__(self, name, value):
# """Run validation if self.value is updated"""
# super().__setattr__(name, value)
# if name == "value":
# # run validation
# self.__class__.validate(self.dict(exclude_unset=True))
#
# @validator("value", always=True)
# def check_value_length(cls, v, values):
# if v is not None:
# # Fetching the limit from the values dictionary
# limit = values.get("limit", 2000) # Default to 2000 if limit is not yet set
#
# # Check if the value exceeds the limit
# if isinstance(v, str):
# length = len(v)
# elif isinstance(v, list):
# length = sum(len(item) for item in v)
# else:
# raise ValueError("Value must be either a string or a list of strings.")
#
# if length > limit:
# error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
# # TODO: add archival memory error?
# raise ValueError(error_msg)
# return v
#
# def __len__(self):
# return len(str(self))
#
# def __str__(self) -> str:
# if isinstance(self.value, list):
# return ",".join(self.value)
# elif isinstance(self.value, str):
# return self.value
# else:
# return ""
#
#
# class BaseMemory:
#
# def __init__(self):
# self.memory = {}
#
# @classmethod
# def load(cls, state: dict):
# """Load memory from dictionary object"""
# obj = cls()
# for key, value in state.items():
# obj.memory[key] = MemoryModule(**value)
# return obj
#
# def __str__(self) -> str:
# """Representation of the memory in-context"""
# section_strs = []
# for section, module in self.memory.items():
# section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
# return "\n".join(section_strs)
#
# def to_dict(self):
# """Convert to dictionary representation"""
# return {key: value.dict() for key, value in self.memory.items()}
#
#
# class ChatMemory(BaseMemory):
#
# def __init__(self, persona: str, human: str, limit: int = 2000):
# self.memory = {
# "persona": MemoryModule(name="persona", value=persona, limit=limit),
# "human": MemoryModule(name="human", value=human, limit=limit),
# }
#
# def core_memory_append(self, name: str, content: str) -> Optional[str]:
# """
# Append to the contents of core memory.
#
# Args:
# name (str): Section of the memory to be edited (persona or human).
# content (str): Content to write to the memory. All unicode (including emojis) are supported.
#
# Returns:
# Optional[str]: None is always returned as this function does not produce a response.
# """
# self.memory[name].value += "\n" + content
# return None
#
# def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
# """
# Replace the contents of core memory. To delete memories, use an empty string for new_content.
#
# Args:
# name (str): Section of the memory to be edited (persona or human).
# old_content (str): String to replace. Must be an exact match.
# new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
#
# Returns:
# Optional[str]: None is always returned as this function does not produce a response.
# """
# self.memory[name].value = self.memory[name].value.replace(old_content, new_content)
# return None
def get_memory_functions(cls: Memory) -> List[callable]:
"""Get memory functions for a memory class"""
@@ -151,94 +43,6 @@ def get_memory_functions(cls: Memory) -> List[callable]:
return functions
# class CoreMemory(object):
# """Held in-context inside the system message
#
# Core Memory: Refers to the system block, which provides essential, foundational context to the AI.
# This includes the persona information, essential user details,
# and any other baseline data you deem necessary for the AI's basic functioning.
# """
#
# def __init__(self, persona=None, human=None, persona_char_limit=None, human_char_limit=None, archival_memory_exists=True):
# self.persona = persona
# self.human = human
# self.persona_char_limit = persona_char_limit
# self.human_char_limit = human_char_limit
#
# # affects the error message the AI will see on overflow inserts
# self.archival_memory_exists = archival_memory_exists
#
# def __repr__(self) -> str:
# return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}"
#
# def to_dict(self):
# return {
# "persona": self.persona,
# "human": self.human,
# }
#
# @classmethod
# def load(cls, state):
# return cls(state["persona"], state["human"])
#
# def edit_persona(self, new_persona):
# if self.persona_char_limit and len(new_persona) > self.persona_char_limit:
# error_msg = f"Edit failed: Exceeds {self.persona_char_limit} character limit (requested {len(new_persona)})."
# if self.archival_memory_exists:
# error_msg = f"{error_msg} Consider summarizing existing core memories in 'persona' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
# raise ValueError(error_msg)
#
# self.persona = new_persona
# return len(self.persona)
#
# def edit_human(self, new_human):
# if self.human_char_limit and len(new_human) > self.human_char_limit:
# error_msg = f"Edit failed: Exceeds {self.human_char_limit} character limit (requested {len(new_human)})."
# if self.archival_memory_exists:
# error_msg = f"{error_msg} Consider summarizing existing core memories in 'human' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
# raise ValueError(error_msg)
#
# self.human = new_human
# return len(self.human)
#
# def edit(self, field, content):
# if field == "persona":
# return self.edit_persona(content)
# elif field == "human":
# return self.edit_human(content)
# else:
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
#
# def edit_append(self, field, content, sep="\n"):
# if field == "persona":
# new_content = self.persona + sep + content
# return self.edit_persona(new_content)
# elif field == "human":
# new_content = self.human + sep + content
# return self.edit_human(new_content)
# else:
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
#
# def edit_replace(self, field, old_content, new_content):
# if len(old_content) == 0:
# raise ValueError("old_content cannot be an empty string (must specify old_content to replace)")
#
# if field == "persona":
# if old_content in self.persona:
# new_persona = self.persona.replace(old_content, new_content)
# return self.edit_persona(new_persona)
# else:
# raise ValueError("Content not found in persona (make sure to use exact string)")
# elif field == "human":
# if old_content in self.human:
# new_human = self.human.replace(old_content, new_content)
# return self.edit_human(new_human)
# else:
# raise ValueError("Content not found in human (make sure to use exact string)")
# else:
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
def _format_summary_history(message_history: List[Message]):
    """Render a message history as newline-separated ``role: text`` lines."""
    # TODO use existing prompt formatters for this (eg ChatML)
    formatted_lines = (f"{msg.role}: {msg.text}" for msg in message_history)
    return "\n".join(formatted_lines)
@@ -308,7 +112,7 @@ class ArchivalMemory(ABC):
"""
@abstractmethod
def __repr__(self) -> str:
def compile(self) -> str:
pass
@@ -322,7 +126,7 @@ class RecallMemory(ABC):
"""Search messages between start_date and end_date in recall memory"""
@abstractmethod
def __repr__(self) -> str:
def compile(self) -> str:
pass
@abstractmethod
@@ -350,7 +154,7 @@ class DummyRecallMemory(RecallMemory):
def __len__(self):
return len(self._message_logs)
def __repr__(self) -> str:
def compile(self) -> str:
# don't dump all the conversations, just statistics
system_count = user_count = assistant_count = function_count = other_count = 0
for msg in self._message_logs:
@@ -467,7 +271,7 @@ class BaseRecallMemory(RecallMemory):
results_json = [message.to_openai_dict_search_results() for message in results]
return results_json, len(results)
def __repr__(self) -> str:
def compile(self) -> str:
total = self.storage.size()
system_count = self.storage.size(filters={"role": "system"})
user_count = self.storage.size(filters={"role": "user"})
@@ -597,7 +401,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
print("Archival search error", e)
raise e
def __repr__(self) -> str:
def compile(self) -> str:
limit = 10
passages = []
for passage in list(self.storage.get_all(limit=limit)): # TODO: only get first 10

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import Optional
from pydantic import Field, model_validator
from typing_extensions import Self
@@ -14,7 +14,7 @@ class BaseBlock(MemGPTBase, validate_assignment=True):
__id_prefix__ = "block"
# data value
value: Optional[Union[List[str], str]] = Field(None, description="Value of the block.")
value: Optional[str] = Field(None, description="Value of the block.")
limit: int = Field(2000, description="Character limit of the block.")
name: Optional[str] = Field(None, description="Name of the block.")

View File

@@ -1,34 +1,81 @@
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional
from jinja2 import Template, TemplateSyntaxError
from pydantic import BaseModel, Field
# Forward referencing to avoid circular import with Agent -> Memory -> Agent
if TYPE_CHECKING:
from memgpt.agent import Agent
from memgpt.schemas.block import Block
class Memory(BaseModel, validate_assignment=True):
"""Represents the in-context memory of the agent"""
# Fields are re-validated on assignment (validate_assignment=True) to avoid assignments with incorrect types
# Memory.memory is a dict mapping from memory block section to memory block.
memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.")
# Memory.prompt_template is a Jinja2 template for compiling the memory blocks into a prompt string.
prompt_template: str = Field(
default="{% for section, block in memory.items() %}"
'<{{ section }} characters="{{ block.value|length }}/{{ block.limit }}">\n'
"{{ block.value }}\n"
"</{{ section }}>"
"{% if not loop.last %}\n{% endif %}"
"{% endfor %}",
description="Jinja2 template for compiling memory blocks into a prompt string",
)
def get_prompt_template(self) -> str:
    """Getter for the memory module's Jinja2 prompt template source string."""
    current_template = self.prompt_template
    return str(current_template)
def set_prompt_template(self, prompt_template: str):
    """Set a new Jinja2 template string for compiling memory into a prompt.

    Validates both the Jinja2 syntax and that the template renders against
    the current memory structure before committing the assignment.

    Args:
        prompt_template (str): Jinja2 template source string.

    Raises:
        ValueError: If the template has invalid Jinja2 syntax, or cannot be
            rendered against the current memory blocks.
    """
    try:
        # Compile once (raises TemplateSyntaxError on bad syntax), then do a
        # trial render against the current memory blocks to catch templates
        # that reference variables the memory structure does not provide.
        # NOTE: previously the template was compiled twice and the trial
        # render was bound to an unused local; compile once and discard.
        compiled_template = Template(prompt_template)
        compiled_template.render(memory=self.memory)
        # Validation passed; commit the new template.
        self.prompt_template = prompt_template
    except TemplateSyntaxError as e:
        # Chain the original error so the Jinja2 traceback is preserved.
        raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}") from e
    except Exception as e:
        raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") from e
@classmethod
def load(cls, state: dict):
"""Load memory from dictionary object"""
obj = cls()
for key, value in state.items():
obj.memory[key] = Block(**value)
if len(state.keys()) == 2 and "memory" in state and "prompt_template" in state:
# New format
obj.prompt_template = state["prompt_template"]
for key, value in state["memory"].items():
obj.memory[key] = Block(**value)
else:
# Old format (pre-template)
for key, value in state.items():
obj.memory[key] = Block(**value)
return obj
def compile(self) -> str:
"""Generate a string representation of the memory in-context"""
section_strs = []
for section, module in self.memory.items():
section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
return "\n".join(section_strs)
"""Generate a string representation of the memory in-context using the Jinja2 template"""
template = Template(self.prompt_template)
return template.render(memory=self.memory)
def to_dict(self):
"""Convert to dictionary representation"""
return {key: value.dict() for key, value in self.memory.items()}
return {
"memory": {key: value.model_dump() for key, value in self.memory.items()},
"prompt_template": self.prompt_template,
}
def to_flat_dict(self):
"""Convert to a dictionary that maps directly from block names to values"""
@@ -41,7 +88,7 @@ class Memory(BaseModel, validate_assignment=True):
def get_block(self, name: str) -> Block:
"""Correct way to index into the memory.memory field, returns a Block"""
if name not in self.memory:
return KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
raise KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
else:
return self.memory[name]
@@ -56,19 +103,20 @@ class Memory(BaseModel, validate_assignment=True):
self.memory[name] = block
def update_block_value(self, name: str, value: Union[List[str], str]):
def update_block_value(self, name: str, value: str):
"""Update the value of a block"""
if name not in self.memory:
raise ValueError(f"Block with name {name} does not exist")
if not (isinstance(value, str) or (isinstance(value, list) and all(isinstance(v, str) for v in value))):
raise ValueError(f"Provided value must be a string or list of strings")
if not isinstance(value, str):
raise ValueError(f"Provided value must be a string")
self.memory[name].value = value
# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
class BaseChatMemory(Memory):
def core_memory_append(self, name: str, content: str) -> Optional[str]:
def core_memory_append(self: "Agent", name: str, content: str) -> Optional[str]: # type: ignore
"""
Append to the contents of core memory.
@@ -84,7 +132,7 @@ class BaseChatMemory(Memory):
self.memory.update_block_value(name=name, value=new_value)
return None
def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
def core_memory_replace(self: "Agent", name: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore
"""
Replace the contents of core memory. To delete memories, use an empty string for new_content.

12
poetry.lock generated
View File

@@ -2557,7 +2557,7 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"]
name = "jinja2"
version = "3.1.4"
description = "A very fast and expressive template engine."
optional = true
optional = false
python-versions = ">=3.7"
files = [
{file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"},
@@ -3367,7 +3367,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
name = "markupsafe"
version = "2.1.5"
description = "Safely add untrusted strings to HTML/XML markup."
optional = true
optional = false
python-versions = ">=3.7"
files = [
{file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
@@ -3502,6 +3502,7 @@ python-versions = ">=3.7"
files = [
{file = "milvus_lite-2.4.9-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d3e617b3d68c09ad656d54bc3d8cc4ef6ef56c54015e1563d4fe4bcec6b7c90a"},
{file = "milvus_lite-2.4.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6e7029282d6829b277ebb92f64e2370be72b938e34770e1eb649346bda5d1d7f"},
{file = "milvus_lite-2.4.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9b8e991e4e433596f6a399a165c1a506f823ec9133332e03d7f8a114bff4550d"},
{file = "milvus_lite-2.4.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:7f53e674602101cfbcf0a4a59d19eaa139dfd5580639f3040ad73d901f24fc0b"},
]
@@ -7129,11 +7130,6 @@ files = [
{file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
{file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
{file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
{file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
{file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
{file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
{file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
{file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
]
[package.dependencies]
@@ -7965,4 +7961,4 @@ server = ["fastapi", "uvicorn", "websockets"]
[metadata]
lock-version = "2.0"
python-versions = "<3.13,>=3.10"
content-hash = "3195feb8a0715fb8a8a191e6c402ae6c01f991921ac7a5629a333e0556d7d02a"
content-hash = "835cba3934be7d79b7f2eb2c5c4c81af2eebe432ed55b186b1616022605e2f70"

View File

@@ -67,6 +67,7 @@ crewai-tools = {version = "^0.8.3", optional = true}
docker = {version = "^7.1.0", optional = true}
tiktoken = "^0.7.0"
nltk = "^3.8.1"
jinja2 = "^3.1.4"
[tool.poetry.extras]
local = ["llama-index-embeddings-huggingface"]

View File

@@ -18,7 +18,7 @@ def test_create_chat_memory():
def test_dump_memory_as_json(sample_memory: Memory):
"""Test dumping ChatMemory as JSON compatible dictionary"""
memory_dict = sample_memory.to_dict()
memory_dict = sample_memory.to_dict()["memory"]
assert isinstance(memory_dict, dict)
assert "persona" in memory_dict
assert memory_dict["persona"]["value"] == "Chat Agent"
@@ -26,7 +26,7 @@ def test_dump_memory_as_json(sample_memory: Memory):
def test_load_memory_from_json(sample_memory: Memory):
"""Test loading ChatMemory from a JSON compatible dictionary"""
memory_dict = sample_memory.to_dict()
memory_dict = sample_memory.to_dict()["memory"]
print(memory_dict)
new_memory = Memory.load(memory_dict)
assert new_memory.get_block("persona").value == "Chat Agent"
@@ -63,3 +63,68 @@ def test_memory_limit_validation(sample_memory: Memory):
with pytest.raises(ValueError):
sample_memory.get_block("persona").value = "x" * 3000
def test_memory_jinja2_template_load(sample_memory: Memory):
    """Test loading a memory with and without a jinja2 template"""
    # Round-trip the new dict format, explicitly carrying the template along.
    state = sample_memory.to_dict()
    state["prompt_template"] = sample_memory.get_prompt_template()
    reloaded = Memory.load(state)
    assert reloaded.get_prompt_template() == sample_memory.get_prompt_template()

    # Round-trip the pre-template (old) format: just the block mapping.
    legacy_state = sample_memory.to_dict()["memory"]
    reloaded = Memory.load(legacy_state)
    # A default template must be installed when none is provided.
    assert reloaded.get_prompt_template() is not None
    assert reloaded.to_dict()["memory"] == legacy_state
def test_memory_jinja2_template(sample_memory: Memory):
    """Test to make sure the jinja2 template string is equivalent to the old __repr__ method"""

    def old_repr(self: Memory) -> str:
        """Generate a string representation of the memory in-context (legacy, pre-template)"""
        section_strs = []
        for section, module in self.memory.items():
            section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
        return "\n".join(section_strs)

    old_repr_str = old_repr(sample_memory)
    new_repr_str = sample_memory.compile()

    # Fix: the original failure message had the operands swapped, which made
    # a mismatch report the expectation backwards.
    assert new_repr_str == old_repr_str, f"Expected compiled output '{new_repr_str}' to equal legacy repr '{old_repr_str}'"
def test_memory_jinja2_set_template(sample_memory: Memory):
    """Test setting the template for the memory"""
    # Re-setting the current (known-good) template must succeed.
    valid_template = sample_memory.get_prompt_template()
    sample_memory.set_prompt_template(prompt_template=valid_template)

    # A syntactically broken template (unterminated Jinja2 block tag) must be
    # rejected with ValueError.
    bad_syntax_template = (
        "{% for section, module in mammoth.items() %}"
        '<{{ section }} characters="{{ module.value|length }}/{{ module.limit }}">\n'
        "{{ module.value }}\n"
        "</{{ section }}>"
        "{% if not loop.last %}\n{% endif %}"
        "{% endfor %"  # Missing closing curly brace
    )
    with pytest.raises(ValueError):
        sample_memory.set_prompt_template(prompt_template=bad_syntax_template)

    # A syntactically valid template that references a variable the memory
    # structure does not provide must also be rejected with ValueError.
    incompatible_template = (
        "{% for section, module in mammoth.items() %}"
        '<{{ section }} characters="{{ module.value|length }}/{{ module.limit }}">\n'
        "{{ module.value }}\n"
        "</{{ section }}>"
        "{% if not loop.last %}\n{% endif %}"
        "{% endfor %}"
    )
    with pytest.raises(ValueError):
        sample_memory.set_prompt_template(prompt_template=incompatible_template)

View File

@@ -42,7 +42,7 @@ def test_agent(client: Union[LocalClient, RESTClient]):
print("TOOLS", [t.name for t in tools])
agent_state = client.get_agent(agent_state_test.id)
assert agent_state.name == "test_agent2"
for block in agent_state.memory.to_dict().values():
for block in agent_state.memory.to_dict()["memory"].values():
db_block = client.server.ms.get_block(block.get("id"))
assert db_block is not None, "memory block not persisted on agent create"
assert db_block.value == block.get("value"), "persisted block data does not match in-memory data"
@@ -134,7 +134,7 @@ def test_agent_with_shared_blocks(client):
)
assert isinstance(first_agent_state_test.memory, Memory)
first_blocks_dict = first_agent_state_test.memory.to_dict()
first_blocks_dict = first_agent_state_test.memory.to_dict()["memory"]
assert persona_block.id == first_blocks_dict.get("persona", {}).get("id")
assert human_block.id == first_blocks_dict.get("human", {}).get("id")
client.update_in_context_memory(first_agent_state_test.id, section="human", value="I'm an analyst therapist.")
@@ -148,7 +148,7 @@ def test_agent_with_shared_blocks(client):
)
assert isinstance(second_agent_state_test.memory, Memory)
second_blocks_dict = second_agent_state_test.memory.to_dict()
second_blocks_dict = second_agent_state_test.memory.to_dict()["memory"]
assert persona_block.id == second_blocks_dict.get("persona", {}).get("id")
assert human_block.id == second_blocks_dict.get("human", {}).get("id")
assert second_blocks_dict.get("human", {}).get("value") == "I'm an analyst therapist."