fix: fix inconsistent name and label usage for blocks to resolve recursive validation issue (#1937)

This commit is contained in:
Sarah Wooders
2024-10-25 18:01:33 -07:00
committed by GitHub
parent 2505eba25f
commit f802969664
7 changed files with 50 additions and 53 deletions

View File

@@ -1153,7 +1153,7 @@ class Agent(BaseAgent):
printd(f"skipping block update, unexpected value: {block_id=}")
continue
# TODO: we may want to update which columns we're updating from shared memory e.g. the limit
self.memory.update_block_value(name=block.get("label", ""), value=db_block.value)
self.memory.update_block_value(label=block.get("label", ""), value=db_block.value)
# If the memory didn't update, we probably don't want to update the timestamp inside
# For example, if we're doing a system prompt swap, this should probably be False

View File

@@ -838,8 +838,8 @@ class RESTClient(AbstractClient):
else:
return [Block(**block) for block in response.json()]
def create_block(self, label: str, name: str, text: str) -> Block: #
request = CreateBlock(label=label, name=name, value=text)
def create_block(self, label: str, text: str, name: Optional[str] = None, template: bool = False) -> Block: #
request = CreateBlock(label=label, value=text, template=template, name=name)
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create block: {response.text}")
@@ -899,13 +899,13 @@ class RESTClient(AbstractClient):
Create a human block template (saved human string to pre-fill `ChatMemory`)
Args:
name (str): Name of the human block
text (str): Text of the human block
name (str): Name of the human block template
text (str): Text of the human block template
Returns:
human (Human): Human block
"""
return self.create_block(label="human", name=name, text=text)
return self.create_block(label="human", name=name, text=text, template=True)
def update_human(self, human_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Human:
"""
@@ -945,7 +945,7 @@ class RESTClient(AbstractClient):
Returns:
persona (Persona): Persona block
"""
return self.create_block(label="persona", name=name, text=text)
return self.create_block(label="persona", name=name, text=text, template=True)
def update_persona(self, persona_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Persona:
"""
@@ -2603,7 +2603,7 @@ class LocalClient(AbstractClient):
"""
return self.server.get_blocks(label=label, template=templates_only)
def create_block(self, name: str, text: str, label: Optional[str] = None) -> Block: #
def create_block(self, label: str, text: str, name: Optional[str] = None, template: bool = False) -> Block: #
"""
Create a block
@@ -2615,7 +2615,9 @@ class LocalClient(AbstractClient):
Returns:
block (Block): Created block
"""
return self.server.create_block(CreateBlock(label=label, name=name, value=text, user_id=self.user_id), user_id=self.user_id)
return self.server.create_block(
CreateBlock(label=label, name=name, value=text, user_id=self.user_id, template=template), user_id=self.user_id
)
def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block:
"""

View File

@@ -306,9 +306,9 @@ class BlockModel(Base):
id = Column(String, primary_key=True, nullable=False)
value = Column(String, nullable=False)
limit = Column(BIGINT)
name = Column(String, nullable=False)
name = Column(String)
template = Column(Boolean, default=False) # True: listed as possible human/persona
label = Column(String)
label = Column(String, nullable=False)
metadata_ = Column(JSON)
description = Column(String)
user_id = Column(String)

View File

@@ -17,11 +17,14 @@ class BaseBlock(LettaBase, validate_assignment=True):
value: Optional[str] = Field(None, description="Value of the block.")
limit: int = Field(2000, description="Character limit of the block.")
name: Optional[str] = Field(None, description="Name of the block.")
# template data (optional)
name: Optional[str] = Field(None, description="Name of the block if it is a template.")
template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).")
label: Optional[str] = Field(None, description="Label of the block (e.g. 'human', 'persona').")
# metadat
# context window label
label: str = Field(None, description="Label of the block (e.g. 'human', 'persona') in the context window.")
# metadata
description: Optional[str] = Field(None, description="Description of the block.")
metadata_: Optional[dict] = Field({}, description="Metadata of the block.")
@@ -39,12 +42,6 @@ class BaseBlock(LettaBase, validate_assignment=True):
raise e
return self
@model_validator(mode="after")
def ensure_label(self) -> Self:
if not self.label:
self.label = self.name
return self
def __len__(self):
return len(self.value)

View File

@@ -61,7 +61,7 @@ class Memory(BaseModel, validate_assignment=True):
"""
# Memory.memory is a dict mapping from memory block section to memory block.
# Memory.memory is a dict mapping from memory block label to memory block.
memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.")
# Memory.template is a Jinja2 template for compiling memory module into a prompt string.
@@ -126,44 +126,42 @@ class Memory(BaseModel, validate_assignment=True):
}
def to_flat_dict(self):
"""Convert to a dictionary that maps directly from block names to values"""
"""Convert to a dictionary that maps directly from block label to values"""
return {k: v.value for k, v in self.memory.items() if v is not None}
def list_block_names(self) -> List[str]:
def list_block_labels(self) -> List[str]:
"""Return a list of the block names held inside the memory object"""
return list(self.memory.keys())
# TODO: these should actually be label, not name
def get_block(self, name: str) -> Block:
def get_block(self, label: str) -> Block:
"""Correct way to index into the memory.memory field, returns a Block"""
if name not in self.memory:
raise KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
if label not in self.memory:
raise KeyError(f"Block field {label} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
else:
return self.memory[name]
return self.memory[label]
def get_blocks(self) -> List[Block]:
"""Return a list of the blocks held inside the memory object"""
return list(self.memory.values())
def link_block(self, name: str, block: Block, override: Optional[bool] = False):
def link_block(self, block: Block, override: Optional[bool] = False):
"""Link a new block to the memory object"""
if not isinstance(block, Block):
raise ValueError(f"Param block must be type Block (not {type(block)})")
if not isinstance(name, str):
raise ValueError(f"Name must be str (not type {type(name)})")
if not override and name in self.memory:
raise ValueError(f"Block with name {name} already exists")
if not override and block.label in self.memory:
raise ValueError(f"Block with label {block.label} already exists")
self.memory[name] = block
self.memory[block.label] = block
def update_block_value(self, name: str, value: str):
def update_block_value(self, label: str, value: str):
"""Update the value of a block"""
if name not in self.memory:
raise ValueError(f"Block with name {name} does not exist")
if label not in self.memory:
raise ValueError(f"Block with label {label} does not exist")
if not isinstance(value, str):
raise ValueError(f"Provided value must be a string")
self.memory[name].value = value
self.memory[label].value = value
# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
@@ -192,41 +190,41 @@ class BasicBlockMemory(Memory):
# assert block.name is not None and block.name != "", "each existing chat block must have a name"
# self.link_block(name=block.name, block=block)
assert block.label is not None and block.label != "", "each existing chat block must have a name"
self.link_block(name=block.label, block=block)
self.link_block(block=block)
def core_memory_append(self: "Agent", name: str, content: str) -> Optional[str]: # type: ignore
def core_memory_append(self: "Agent", label: str, content: str) -> Optional[str]: # type: ignore
"""
Append to the contents of core memory.
Args:
name (str): Section of the memory to be edited (persona or human).
label (str): Section of the memory to be edited (persona or human).
content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
current_value = str(self.memory.get_block(name).value)
current_value = str(self.memory.get_block(label).value)
new_value = current_value + "\n" + str(content)
self.memory.update_block_value(name=name, value=new_value)
self.memory.update_block_value(label=label, value=new_value)
return None
def core_memory_replace(self: "Agent", name: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore
def core_memory_replace(self: "Agent", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore
"""
Replace the contents of core memory. To delete memories, use an empty string for new_content.
Args:
name (str): Section of the memory to be edited (persona or human).
label (str): Section of the memory to be edited (persona or human).
old_content (str): String to replace. Must be an exact match.
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
current_value = str(self.memory.get_block(name).value)
current_value = str(self.memory.get_block(label).value)
if old_content not in current_value:
raise ValueError(f"Old content '{old_content}' not found in memory block '{name}'")
raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'")
new_value = current_value.replace(str(old_content), str(new_content))
self.memory.update_block_value(name=name, value=new_value)
self.memory.update_block_value(label=label, value=new_value)
return None
@@ -245,8 +243,8 @@ class ChatMemory(BasicBlockMemory):
limit (int): The character limit for each block.
"""
super().__init__()
self.link_block(name="persona", block=Block(name="persona", value=persona, limit=limit, label="persona"))
self.link_block(name="human", block=Block(name="human", value=human, limit=limit, label="human"))
self.link_block(block=Block(value=persona, limit=limit, label="persona"))
self.link_block(block=Block(value=human, limit=limit, label="human"))
class UpdateMemory(BaseModel):

View File

@@ -1098,7 +1098,7 @@ class SyncServer(Server):
block.value = request.value if request.value is not None else block.value
block.name = request.name if request.name is not None else block.name
self.ms.update_block(block=block)
return block
return self.ms.get_block(block_id=request.id)
def delete_block(self, block_id: str):
block = self.get_block(block_id)
@@ -1413,7 +1413,7 @@ class SyncServer(Server):
if value is None:
continue
if letta_agent.memory.get_block(key) != value:
letta_agent.memory.update_block_value(name=key, value=value) # update agent memory
letta_agent.memory.update_block_value(label=key, value=value) # update agent memory
modified = True
# If we modified the memory contents, we need to rebuild the memory block inside the system message

View File

@@ -134,8 +134,8 @@ def test_create_agent_tool(client):
str: The agent that was deleted.
"""
self.memory.update_block_value(name="human", value="")
self.memory.update_block_value(name="persona", value="")
self.memory.update_block_value(label="human", value="")
self.memory.update_block_value(label="persona", value="")
print("UPDATED MEMORY", self.memory.memory)
return None