refactor: use jinja2 templates for Memory.compile instead of writing Python code (#1687)
This commit is contained in:
@@ -698,7 +698,7 @@ class Agent(object):
|
||||
# TODO: ensure we're passing in metadata store from all surfaces
|
||||
if ms is not None:
|
||||
should_update = False
|
||||
for block in self.agent_state.memory.to_dict().values():
|
||||
for block in self.agent_state.memory.to_dict()["memory"].values():
|
||||
if not block.get("template", False):
|
||||
should_update = True
|
||||
if should_update:
|
||||
@@ -1020,7 +1020,7 @@ class Agent(object):
|
||||
return
|
||||
|
||||
if ms:
|
||||
for block in self.memory.to_dict().values():
|
||||
for block in self.memory.to_dict()["memory"].values():
|
||||
if block.get("templates", False):
|
||||
# we don't expect to update shared memory blocks that
|
||||
# are templates. this is something we could update in the
|
||||
@@ -1225,7 +1225,7 @@ def save_agent_memory(agent: Agent, ms: MetadataStore):
|
||||
NOTE: we are assuming agent.update_state has already been called.
|
||||
"""
|
||||
|
||||
for block_dict in agent.memory.to_dict().values():
|
||||
for block_dict in agent.memory.to_dict()["memory"].values():
|
||||
# TODO: block creation should happen in one place to enforce these sort of constraints consistently.
|
||||
if block_dict.get("user_id", None) is None:
|
||||
block_dict["user_id"] = agent.agent_state.user_id
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
@@ -190,9 +189,9 @@ def run_agent_loop(
|
||||
|
||||
elif user_input.lower() == "/memory":
|
||||
print(f"\nDumping memory contents:\n")
|
||||
print(f"{str(memgpt_agent.memory)}")
|
||||
print(f"{str(memgpt_agent.persistence_manager.archival_memory)}")
|
||||
print(f"{str(memgpt_agent.persistence_manager.recall_memory)}")
|
||||
print(f"{memgpt_agent.memory.compile()}")
|
||||
print(f"{memgpt_agent.persistence_manager.archival_memory.compile()}")
|
||||
print(f"{memgpt_agent.persistence_manager.recall_memory.compile()}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/model":
|
||||
|
||||
206
memgpt/memory.py
206
memgpt/memory.py
@@ -19,114 +19,6 @@ from memgpt.utils import (
|
||||
validate_date_format,
|
||||
)
|
||||
|
||||
# class MemoryModule(BaseModel):
|
||||
# """Base class for memory modules"""
|
||||
#
|
||||
# description: Optional[str] = None
|
||||
# limit: int = 2000
|
||||
# value: Optional[Union[List[str], str]] = None
|
||||
#
|
||||
# def __setattr__(self, name, value):
|
||||
# """Run validation if self.value is updated"""
|
||||
# super().__setattr__(name, value)
|
||||
# if name == "value":
|
||||
# # run validation
|
||||
# self.__class__.validate(self.dict(exclude_unset=True))
|
||||
#
|
||||
# @validator("value", always=True)
|
||||
# def check_value_length(cls, v, values):
|
||||
# if v is not None:
|
||||
# # Fetching the limit from the values dictionary
|
||||
# limit = values.get("limit", 2000) # Default to 2000 if limit is not yet set
|
||||
#
|
||||
# # Check if the value exceeds the limit
|
||||
# if isinstance(v, str):
|
||||
# length = len(v)
|
||||
# elif isinstance(v, list):
|
||||
# length = sum(len(item) for item in v)
|
||||
# else:
|
||||
# raise ValueError("Value must be either a string or a list of strings.")
|
||||
#
|
||||
# if length > limit:
|
||||
# error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
|
||||
# # TODO: add archival memory error?
|
||||
# raise ValueError(error_msg)
|
||||
# return v
|
||||
#
|
||||
# def __len__(self):
|
||||
# return len(str(self))
|
||||
#
|
||||
# def __str__(self) -> str:
|
||||
# if isinstance(self.value, list):
|
||||
# return ",".join(self.value)
|
||||
# elif isinstance(self.value, str):
|
||||
# return self.value
|
||||
# else:
|
||||
# return ""
|
||||
#
|
||||
#
|
||||
# class BaseMemory:
|
||||
#
|
||||
# def __init__(self):
|
||||
# self.memory = {}
|
||||
#
|
||||
# @classmethod
|
||||
# def load(cls, state: dict):
|
||||
# """Load memory from dictionary object"""
|
||||
# obj = cls()
|
||||
# for key, value in state.items():
|
||||
# obj.memory[key] = MemoryModule(**value)
|
||||
# return obj
|
||||
#
|
||||
# def __str__(self) -> str:
|
||||
# """Representation of the memory in-context"""
|
||||
# section_strs = []
|
||||
# for section, module in self.memory.items():
|
||||
# section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
|
||||
# return "\n".join(section_strs)
|
||||
#
|
||||
# def to_dict(self):
|
||||
# """Convert to dictionary representation"""
|
||||
# return {key: value.dict() for key, value in self.memory.items()}
|
||||
#
|
||||
#
|
||||
# class ChatMemory(BaseMemory):
|
||||
#
|
||||
# def __init__(self, persona: str, human: str, limit: int = 2000):
|
||||
# self.memory = {
|
||||
# "persona": MemoryModule(name="persona", value=persona, limit=limit),
|
||||
# "human": MemoryModule(name="human", value=human, limit=limit),
|
||||
# }
|
||||
#
|
||||
# def core_memory_append(self, name: str, content: str) -> Optional[str]:
|
||||
# """
|
||||
# Append to the contents of core memory.
|
||||
#
|
||||
# Args:
|
||||
# name (str): Section of the memory to be edited (persona or human).
|
||||
# content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
#
|
||||
# Returns:
|
||||
# Optional[str]: None is always returned as this function does not produce a response.
|
||||
# """
|
||||
# self.memory[name].value += "\n" + content
|
||||
# return None
|
||||
#
|
||||
# def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
|
||||
# """
|
||||
# Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
#
|
||||
# Args:
|
||||
# name (str): Section of the memory to be edited (persona or human).
|
||||
# old_content (str): String to replace. Must be an exact match.
|
||||
# new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
#
|
||||
# Returns:
|
||||
# Optional[str]: None is always returned as this function does not produce a response.
|
||||
# """
|
||||
# self.memory[name].value = self.memory[name].value.replace(old_content, new_content)
|
||||
# return None
|
||||
|
||||
|
||||
def get_memory_functions(cls: Memory) -> List[callable]:
|
||||
"""Get memory functions for a memory class"""
|
||||
@@ -151,94 +43,6 @@ def get_memory_functions(cls: Memory) -> List[callable]:
|
||||
return functions
|
||||
|
||||
|
||||
# class CoreMemory(object):
|
||||
# """Held in-context inside the system message
|
||||
#
|
||||
# Core Memory: Refers to the system block, which provides essential, foundational context to the AI.
|
||||
# This includes the persona information, essential user details,
|
||||
# and any other baseline data you deem necessary for the AI's basic functioning.
|
||||
# """
|
||||
#
|
||||
# def __init__(self, persona=None, human=None, persona_char_limit=None, human_char_limit=None, archival_memory_exists=True):
|
||||
# self.persona = persona
|
||||
# self.human = human
|
||||
# self.persona_char_limit = persona_char_limit
|
||||
# self.human_char_limit = human_char_limit
|
||||
#
|
||||
# # affects the error message the AI will see on overflow inserts
|
||||
# self.archival_memory_exists = archival_memory_exists
|
||||
#
|
||||
# def __repr__(self) -> str:
|
||||
# return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}"
|
||||
#
|
||||
# def to_dict(self):
|
||||
# return {
|
||||
# "persona": self.persona,
|
||||
# "human": self.human,
|
||||
# }
|
||||
#
|
||||
# @classmethod
|
||||
# def load(cls, state):
|
||||
# return cls(state["persona"], state["human"])
|
||||
#
|
||||
# def edit_persona(self, new_persona):
|
||||
# if self.persona_char_limit and len(new_persona) > self.persona_char_limit:
|
||||
# error_msg = f"Edit failed: Exceeds {self.persona_char_limit} character limit (requested {len(new_persona)})."
|
||||
# if self.archival_memory_exists:
|
||||
# error_msg = f"{error_msg} Consider summarizing existing core memories in 'persona' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
|
||||
# raise ValueError(error_msg)
|
||||
#
|
||||
# self.persona = new_persona
|
||||
# return len(self.persona)
|
||||
#
|
||||
# def edit_human(self, new_human):
|
||||
# if self.human_char_limit and len(new_human) > self.human_char_limit:
|
||||
# error_msg = f"Edit failed: Exceeds {self.human_char_limit} character limit (requested {len(new_human)})."
|
||||
# if self.archival_memory_exists:
|
||||
# error_msg = f"{error_msg} Consider summarizing existing core memories in 'human' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
|
||||
# raise ValueError(error_msg)
|
||||
#
|
||||
# self.human = new_human
|
||||
# return len(self.human)
|
||||
#
|
||||
# def edit(self, field, content):
|
||||
# if field == "persona":
|
||||
# return self.edit_persona(content)
|
||||
# elif field == "human":
|
||||
# return self.edit_human(content)
|
||||
# else:
|
||||
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
|
||||
#
|
||||
# def edit_append(self, field, content, sep="\n"):
|
||||
# if field == "persona":
|
||||
# new_content = self.persona + sep + content
|
||||
# return self.edit_persona(new_content)
|
||||
# elif field == "human":
|
||||
# new_content = self.human + sep + content
|
||||
# return self.edit_human(new_content)
|
||||
# else:
|
||||
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
|
||||
#
|
||||
# def edit_replace(self, field, old_content, new_content):
|
||||
# if len(old_content) == 0:
|
||||
# raise ValueError("old_content cannot be an empty string (must specify old_content to replace)")
|
||||
#
|
||||
# if field == "persona":
|
||||
# if old_content in self.persona:
|
||||
# new_persona = self.persona.replace(old_content, new_content)
|
||||
# return self.edit_persona(new_persona)
|
||||
# else:
|
||||
# raise ValueError("Content not found in persona (make sure to use exact string)")
|
||||
# elif field == "human":
|
||||
# if old_content in self.human:
|
||||
# new_human = self.human.replace(old_content, new_content)
|
||||
# return self.edit_human(new_human)
|
||||
# else:
|
||||
# raise ValueError("Content not found in human (make sure to use exact string)")
|
||||
# else:
|
||||
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
|
||||
|
||||
|
||||
def _format_summary_history(message_history: List[Message]):
|
||||
# TODO use existing prompt formatters for this (eg ChatML)
|
||||
return "\n".join([f"{m.role}: {m.text}" for m in message_history])
|
||||
@@ -308,7 +112,7 @@ class ArchivalMemory(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
def compile(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
@@ -322,7 +126,7 @@ class RecallMemory(ABC):
|
||||
"""Search messages between start_date and end_date in recall memory"""
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
def compile(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -350,7 +154,7 @@ class DummyRecallMemory(RecallMemory):
|
||||
def __len__(self):
|
||||
return len(self._message_logs)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def compile(self) -> str:
|
||||
# don't dump all the conversations, just statistics
|
||||
system_count = user_count = assistant_count = function_count = other_count = 0
|
||||
for msg in self._message_logs:
|
||||
@@ -467,7 +271,7 @@ class BaseRecallMemory(RecallMemory):
|
||||
results_json = [message.to_openai_dict_search_results() for message in results]
|
||||
return results_json, len(results)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def compile(self) -> str:
|
||||
total = self.storage.size()
|
||||
system_count = self.storage.size(filters={"role": "system"})
|
||||
user_count = self.storage.size(filters={"role": "user"})
|
||||
@@ -597,7 +401,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
print("Archival search error", e)
|
||||
raise e
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def compile(self) -> str:
|
||||
limit = 10
|
||||
passages = []
|
||||
for passage in list(self.storage.get_all(limit=limit)): # TODO: only get first 10
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
@@ -14,7 +14,7 @@ class BaseBlock(MemGPTBase, validate_assignment=True):
|
||||
__id_prefix__ = "block"
|
||||
|
||||
# data value
|
||||
value: Optional[Union[List[str], str]] = Field(None, description="Value of the block.")
|
||||
value: Optional[str] = Field(None, description="Value of the block.")
|
||||
limit: int = Field(2000, description="Character limit of the block.")
|
||||
|
||||
name: Optional[str] = Field(None, description="Name of the block.")
|
||||
|
||||
@@ -1,34 +1,81 @@
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from jinja2 import Template, TemplateSyntaxError
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Forward referencing to avoid circular import with Agent -> Memory -> Agent
|
||||
if TYPE_CHECKING:
|
||||
from memgpt.agent import Agent
|
||||
|
||||
from memgpt.schemas.block import Block
|
||||
|
||||
|
||||
class Memory(BaseModel, validate_assignment=True):
|
||||
"""Represents the in-context memory of the agent"""
|
||||
|
||||
# Private variable to avoid assignments with incorrect types
|
||||
# Memory.memory is a dict mapping from memory block section to memory block.
|
||||
memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.")
|
||||
|
||||
# Memory.template is a Jinja2 template for compiling memory module into a prompt string.
|
||||
prompt_template: str = Field(
|
||||
default="{% for section, block in memory.items() %}"
|
||||
'<{{ section }} characters="{{ block.value|length }}/{{ block.limit }}">\n'
|
||||
"{{ block.value }}\n"
|
||||
"</{{ section }}>"
|
||||
"{% if not loop.last %}\n{% endif %}"
|
||||
"{% endfor %}",
|
||||
description="Jinja2 template for compiling memory blocks into a prompt string",
|
||||
)
|
||||
|
||||
def get_prompt_template(self) -> str:
|
||||
"""Return the current Jinja2 template string."""
|
||||
return str(self.prompt_template)
|
||||
|
||||
def set_prompt_template(self, prompt_template: str):
|
||||
"""
|
||||
Set a new Jinja2 template string.
|
||||
Validates the template syntax and compatibility with current memory structure.
|
||||
"""
|
||||
try:
|
||||
# Validate Jinja2 syntax
|
||||
Template(prompt_template)
|
||||
|
||||
# Validate compatibility with current memory structure
|
||||
test_render = Template(prompt_template).render(memory=self.memory)
|
||||
|
||||
# If we get here, the template is valid and compatible
|
||||
self.prompt_template = prompt_template
|
||||
except TemplateSyntaxError as e:
|
||||
raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
|
||||
|
||||
@classmethod
|
||||
def load(cls, state: dict):
|
||||
"""Load memory from dictionary object"""
|
||||
obj = cls()
|
||||
for key, value in state.items():
|
||||
obj.memory[key] = Block(**value)
|
||||
if len(state.keys()) == 2 and "memory" in state and "prompt_template" in state:
|
||||
# New format
|
||||
obj.prompt_template = state["prompt_template"]
|
||||
for key, value in state["memory"].items():
|
||||
obj.memory[key] = Block(**value)
|
||||
else:
|
||||
# Old format (pre-template)
|
||||
for key, value in state.items():
|
||||
obj.memory[key] = Block(**value)
|
||||
return obj
|
||||
|
||||
def compile(self) -> str:
|
||||
"""Generate a string representation of the memory in-context"""
|
||||
section_strs = []
|
||||
for section, module in self.memory.items():
|
||||
section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
|
||||
return "\n".join(section_strs)
|
||||
"""Generate a string representation of the memory in-context using the Jinja2 template"""
|
||||
template = Template(self.prompt_template)
|
||||
return template.render(memory=self.memory)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary representation"""
|
||||
return {key: value.dict() for key, value in self.memory.items()}
|
||||
return {
|
||||
"memory": {key: value.model_dump() for key, value in self.memory.items()},
|
||||
"prompt_template": self.prompt_template,
|
||||
}
|
||||
|
||||
def to_flat_dict(self):
|
||||
"""Convert to a dictionary that maps directly from block names to values"""
|
||||
@@ -41,7 +88,7 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
def get_block(self, name: str) -> Block:
|
||||
"""Correct way to index into the memory.memory field, returns a Block"""
|
||||
if name not in self.memory:
|
||||
return KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
|
||||
raise KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
|
||||
else:
|
||||
return self.memory[name]
|
||||
|
||||
@@ -56,19 +103,20 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
|
||||
self.memory[name] = block
|
||||
|
||||
def update_block_value(self, name: str, value: Union[List[str], str]):
|
||||
def update_block_value(self, name: str, value: str):
|
||||
"""Update the value of a block"""
|
||||
if name not in self.memory:
|
||||
raise ValueError(f"Block with name {name} does not exist")
|
||||
if not (isinstance(value, str) or (isinstance(value, list) and all(isinstance(v, str) for v in value))):
|
||||
raise ValueError(f"Provided value must be a string or list of strings")
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"Provided value must be a string")
|
||||
|
||||
self.memory[name].value = value
|
||||
|
||||
|
||||
# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
|
||||
class BaseChatMemory(Memory):
|
||||
def core_memory_append(self, name: str, content: str) -> Optional[str]:
|
||||
|
||||
def core_memory_append(self: "Agent", name: str, content: str) -> Optional[str]: # type: ignore
|
||||
"""
|
||||
Append to the contents of core memory.
|
||||
|
||||
@@ -84,7 +132,7 @@ class BaseChatMemory(Memory):
|
||||
self.memory.update_block_value(name=name, value=new_value)
|
||||
return None
|
||||
|
||||
def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
|
||||
def core_memory_replace(self: "Agent", name: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore
|
||||
"""
|
||||
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
|
||||
|
||||
12
poetry.lock
generated
12
poetry.lock
generated
@@ -2557,7 +2557,7 @@ testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"]
|
||||
name = "jinja2"
|
||||
version = "3.1.4"
|
||||
description = "A very fast and expressive template engine."
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"},
|
||||
@@ -3367,7 +3367,7 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
||||
name = "markupsafe"
|
||||
version = "2.1.5"
|
||||
description = "Safely add untrusted strings to HTML/XML markup."
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
|
||||
@@ -3502,6 +3502,7 @@ python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "milvus_lite-2.4.9-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d3e617b3d68c09ad656d54bc3d8cc4ef6ef56c54015e1563d4fe4bcec6b7c90a"},
|
||||
{file = "milvus_lite-2.4.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6e7029282d6829b277ebb92f64e2370be72b938e34770e1eb649346bda5d1d7f"},
|
||||
{file = "milvus_lite-2.4.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9b8e991e4e433596f6a399a165c1a506f823ec9133332e03d7f8a114bff4550d"},
|
||||
{file = "milvus_lite-2.4.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:7f53e674602101cfbcf0a4a59d19eaa139dfd5580639f3040ad73d901f24fc0b"},
|
||||
]
|
||||
|
||||
@@ -7129,11 +7130,6 @@ files = [
|
||||
{file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"},
|
||||
{file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"},
|
||||
{file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"},
|
||||
{file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"},
|
||||
{file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"},
|
||||
{file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"},
|
||||
{file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"},
|
||||
{file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -7965,4 +7961,4 @@ server = ["fastapi", "uvicorn", "websockets"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.13,>=3.10"
|
||||
content-hash = "3195feb8a0715fb8a8a191e6c402ae6c01f991921ac7a5629a333e0556d7d02a"
|
||||
content-hash = "835cba3934be7d79b7f2eb2c5c4c81af2eebe432ed55b186b1616022605e2f70"
|
||||
|
||||
@@ -67,6 +67,7 @@ crewai-tools = {version = "^0.8.3", optional = true}
|
||||
docker = {version = "^7.1.0", optional = true}
|
||||
tiktoken = "^0.7.0"
|
||||
nltk = "^3.8.1"
|
||||
jinja2 = "^3.1.4"
|
||||
|
||||
[tool.poetry.extras]
|
||||
local = ["llama-index-embeddings-huggingface"]
|
||||
|
||||
@@ -18,7 +18,7 @@ def test_create_chat_memory():
|
||||
|
||||
def test_dump_memory_as_json(sample_memory: Memory):
|
||||
"""Test dumping ChatMemory as JSON compatible dictionary"""
|
||||
memory_dict = sample_memory.to_dict()
|
||||
memory_dict = sample_memory.to_dict()["memory"]
|
||||
assert isinstance(memory_dict, dict)
|
||||
assert "persona" in memory_dict
|
||||
assert memory_dict["persona"]["value"] == "Chat Agent"
|
||||
@@ -26,7 +26,7 @@ def test_dump_memory_as_json(sample_memory: Memory):
|
||||
|
||||
def test_load_memory_from_json(sample_memory: Memory):
|
||||
"""Test loading ChatMemory from a JSON compatible dictionary"""
|
||||
memory_dict = sample_memory.to_dict()
|
||||
memory_dict = sample_memory.to_dict()["memory"]
|
||||
print(memory_dict)
|
||||
new_memory = Memory.load(memory_dict)
|
||||
assert new_memory.get_block("persona").value == "Chat Agent"
|
||||
@@ -63,3 +63,68 @@ def test_memory_limit_validation(sample_memory: Memory):
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
sample_memory.get_block("persona").value = "x" * 3000
|
||||
|
||||
|
||||
def test_memory_jinja2_template_load(sample_memory: Memory):
|
||||
"""Test loading a memory with and without a jinja2 template"""
|
||||
|
||||
# Test loading a memory with a template
|
||||
memory_dict = sample_memory.to_dict()
|
||||
memory_dict["prompt_template"] = sample_memory.get_prompt_template()
|
||||
new_memory = Memory.load(memory_dict)
|
||||
assert new_memory.get_prompt_template() == sample_memory.get_prompt_template()
|
||||
|
||||
# Test loading a memory without a template (old format)
|
||||
memory_dict = sample_memory.to_dict()
|
||||
memory_dict_old_format = memory_dict["memory"]
|
||||
new_memory = Memory.load(memory_dict_old_format)
|
||||
assert new_memory.get_prompt_template() is not None # Ensure a default template is set
|
||||
assert new_memory.to_dict()["memory"] == memory_dict_old_format
|
||||
|
||||
|
||||
def test_memory_jinja2_template(sample_memory: Memory):
|
||||
"""Test to make sure the jinja2 template string is equivalent to the old __repr__ method"""
|
||||
|
||||
def old_repr(self: Memory) -> str:
|
||||
"""Generate a string representation of the memory in-context"""
|
||||
section_strs = []
|
||||
for section, module in self.memory.items():
|
||||
section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
|
||||
return "\n".join(section_strs)
|
||||
|
||||
old_repr_str = old_repr(sample_memory)
|
||||
new_repr_str = sample_memory.compile()
|
||||
assert new_repr_str == old_repr_str, f"Expected '{old_repr_str}' to be '{new_repr_str}'"
|
||||
|
||||
|
||||
def test_memory_jinja2_set_template(sample_memory: Memory):
|
||||
"""Test setting the template for the memory"""
|
||||
|
||||
example_template = sample_memory.get_prompt_template()
|
||||
|
||||
# Try setting a valid template
|
||||
sample_memory.set_prompt_template(prompt_template=example_template)
|
||||
|
||||
# Try setting an invalid template (bad jinja2)
|
||||
template_bad_jinja = (
|
||||
"{% for section, module in mammoth.items() %}"
|
||||
'<{{ section }} characters="{{ module.value|length }}/{{ module.limit }}">\n'
|
||||
"{{ module.value }}\n"
|
||||
"</{{ section }}>"
|
||||
"{% if not loop.last %}\n{% endif %}"
|
||||
"{% endfor %" # Missing closing curly brace
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
sample_memory.set_prompt_template(prompt_template=template_bad_jinja)
|
||||
|
||||
# Try setting an invalid template (not compatible with memory structure)
|
||||
template_bad_memory_structure = (
|
||||
"{% for section, module in mammoth.items() %}"
|
||||
'<{{ section }} characters="{{ module.value|length }}/{{ module.limit }}">\n'
|
||||
"{{ module.value }}\n"
|
||||
"</{{ section }}>"
|
||||
"{% if not loop.last %}\n{% endif %}"
|
||||
"{% endfor %}"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
sample_memory.set_prompt_template(prompt_template=template_bad_memory_structure)
|
||||
|
||||
@@ -42,7 +42,7 @@ def test_agent(client: Union[LocalClient, RESTClient]):
|
||||
print("TOOLS", [t.name for t in tools])
|
||||
agent_state = client.get_agent(agent_state_test.id)
|
||||
assert agent_state.name == "test_agent2"
|
||||
for block in agent_state.memory.to_dict().values():
|
||||
for block in agent_state.memory.to_dict()["memory"].values():
|
||||
db_block = client.server.ms.get_block(block.get("id"))
|
||||
assert db_block is not None, "memory block not persisted on agent create"
|
||||
assert db_block.value == block.get("value"), "persisted block data does not match in-memory data"
|
||||
@@ -134,7 +134,7 @@ def test_agent_with_shared_blocks(client):
|
||||
)
|
||||
assert isinstance(first_agent_state_test.memory, Memory)
|
||||
|
||||
first_blocks_dict = first_agent_state_test.memory.to_dict()
|
||||
first_blocks_dict = first_agent_state_test.memory.to_dict()["memory"]
|
||||
assert persona_block.id == first_blocks_dict.get("persona", {}).get("id")
|
||||
assert human_block.id == first_blocks_dict.get("human", {}).get("id")
|
||||
client.update_in_context_memory(first_agent_state_test.id, section="human", value="I'm an analyst therapist.")
|
||||
@@ -148,7 +148,7 @@ def test_agent_with_shared_blocks(client):
|
||||
)
|
||||
|
||||
assert isinstance(second_agent_state_test.memory, Memory)
|
||||
second_blocks_dict = second_agent_state_test.memory.to_dict()
|
||||
second_blocks_dict = second_agent_state_test.memory.to_dict()["memory"]
|
||||
assert persona_block.id == second_blocks_dict.get("persona", {}).get("id")
|
||||
assert human_block.id == second_blocks_dict.get("human", {}).get("id")
|
||||
assert second_blocks_dict.get("human", {}).get("value") == "I'm an analyst therapist."
|
||||
|
||||
Reference in New Issue
Block a user