From f802969664b2f996ac50756c3a70b7dda3a675cc Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 25 Oct 2024 18:01:33 -0700 Subject: [PATCH] fix: fix inconsistent name and label usage for blocks to resolve recursive validation issue (#1937) --- letta/agent.py | 2 +- letta/client/client.py | 18 +++++++------ letta/metadata.py | 4 +-- letta/schemas/block.py | 15 +++++------ letta/schemas/memory.py | 56 ++++++++++++++++++++--------------------- letta/server/server.py | 4 +-- tests/test_tools.py | 4 +-- 7 files changed, 50 insertions(+), 53 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index c865993d..e18e989a 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1153,7 +1153,7 @@ class Agent(BaseAgent): printd(f"skipping block update, unexpected value: {block_id=}") continue # TODO: we may want to update which columns we're updating from shared memory e.g. the limit - self.memory.update_block_value(name=block.get("label", ""), value=db_block.value) + self.memory.update_block_value(label=block.get("label", ""), value=db_block.value) # If the memory didn't update, we probably don't want to update the timestamp inside # For example, if we're doing a system prompt swap, this should probably be False diff --git a/letta/client/client.py b/letta/client/client.py index e21a0cb0..fd2c49eb 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -838,8 +838,8 @@ class RESTClient(AbstractClient): else: return [Block(**block) for block in response.json()] - def create_block(self, label: str, name: str, text: str) -> Block: # - request = CreateBlock(label=label, name=name, value=text) + def create_block(self, label: str, text: str, name: Optional[str] = None, template: bool = False) -> Block: # + request = CreateBlock(label=label, value=text, template=template, name=name) response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create block: {response.text}") @@ -899,13 +899,13 @@ class RESTClient(AbstractClient): Create a human block template (saved human string to pre-fill `ChatMemory`) Args: - name (str): Name of the human block - text (str): Text of the human block + name (str): Name of the human block template + text (str): Text of the human block template Returns: human (Human): Human block """ - return self.create_block(label="human", name=name, text=text) + return self.create_block(label="human", name=name, text=text, template=True) def update_human(self, human_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Human: """ @@ -945,7 +945,7 @@ class RESTClient(AbstractClient): Returns: persona (Persona): Persona block """ - return self.create_block(label="persona", name=name, text=text) + return self.create_block(label="persona", name=name, text=text, template=True) def update_persona(self, persona_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Persona: """ @@ -2603,7 +2603,7 @@ class LocalClient(AbstractClient): """ return self.server.get_blocks(label=label, template=templates_only) - def create_block(self, name: str, text: str, label: Optional[str] = None) -> Block: # + def create_block(self, label: str, text: str, name: Optional[str] = None, template: bool = False) -> Block: # """ Create a block @@ -2615,7 +2615,9 @@ class LocalClient(AbstractClient): Returns: block (Block): Created block """ - return self.server.create_block(CreateBlock(label=label, name=name, value=text, user_id=self.user_id), user_id=self.user_id) + return self.server.create_block( + CreateBlock(label=label, name=name, value=text, user_id=self.user_id, template=template), user_id=self.user_id + ) def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block: """ diff --git a/letta/metadata.py b/letta/metadata.py index b0150ac7..9c2761d2 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -306,9 +306,9 @@ class BlockModel(Base): id = Column(String, primary_key=True, nullable=False) value = Column(String, nullable=False) limit = Column(BIGINT) - name = Column(String, nullable=False) + name = Column(String) template = Column(Boolean, default=False) # True: listed as possible human/persona - label = Column(String) + label = Column(String, nullable=False) metadata_ = Column(JSON) description = Column(String) user_id = Column(String) diff --git a/letta/schemas/block.py b/letta/schemas/block.py index 8af2f47c..a7eedb0a 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -17,11 +17,14 @@ class BaseBlock(LettaBase, validate_assignment=True): value: Optional[str] = Field(None, description="Value of the block.") limit: int = Field(2000, description="Character limit of the block.") - name: Optional[str] = Field(None, description="Name of the block.") + # template data (optional) + name: Optional[str] = Field(None, description="Name of the block if it is a template.") template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).") - label: Optional[str] = Field(None, description="Label of the block (e.g. 'human', 'persona').") - # metadat + # context window label + label: str = Field(None, description="Label of the block (e.g. 'human', 'persona') in the context window.") + + # metadata description: Optional[str] = Field(None, description="Description of the block.") metadata_: Optional[dict] = Field({}, description="Metadata of the block.") @@ -39,12 +42,6 @@ class BaseBlock(LettaBase, validate_assignment=True): raise e return self - @model_validator(mode="after") - def ensure_label(self) -> Self: - if not self.label: - self.label = self.name - return self - def __len__(self): return len(self.value) diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index a6c5ad02..ae3b34dd 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -61,7 +61,7 @@ class Memory(BaseModel, validate_assignment=True): """ - # Memory.memory is a dict mapping from memory block section to memory block. + # Memory.memory is a dict mapping from memory block label to memory block. memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.") # Memory.template is a Jinja2 template for compiling memory module into a prompt string. @@ -126,44 +126,42 @@ class Memory(BaseModel, validate_assignment=True): } def to_flat_dict(self): - """Convert to a dictionary that maps directly from block names to values""" + """Convert to a dictionary that maps directly from block label to values""" return {k: v.value for k, v in self.memory.items() if v is not None} - def list_block_names(self) -> List[str]: + def list_block_labels(self) -> List[str]: """Return a list of the block names held inside the memory object""" return list(self.memory.keys()) # TODO: these should actually be label, not name - def get_block(self, name: str) -> Block: + def get_block(self, label: str) -> Block: """Correct way to index into the memory.memory field, returns a Block""" - if name not in self.memory: - raise KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})") + if label not in self.memory: + raise KeyError(f"Block field {label} does not exist (available sections = {', '.join(list(self.memory.keys()))})") else: - return self.memory[name] + return self.memory[label] def get_blocks(self) -> List[Block]: """Return a list of the blocks held inside the memory object""" return list(self.memory.values()) - def link_block(self, name: str, block: Block, override: Optional[bool] = False): + def link_block(self, block: Block, override: Optional[bool] = False): """Link a new block to the memory object""" if not isinstance(block, Block): raise ValueError(f"Param block must be type Block (not {type(block)})") - if not isinstance(name, str): - raise ValueError(f"Name must be str (not type {type(name)})") - if not override and name in self.memory: - raise ValueError(f"Block with name {name} already exists") + if not override and block.label in self.memory: + raise ValueError(f"Block with label {block.label} already exists") - self.memory[name] = block + self.memory[block.label] = block - def update_block_value(self, name: str, value: str): + def update_block_value(self, label: str, value: str): """Update the value of a block""" - if name not in self.memory: - raise ValueError(f"Block with name {name} does not exist") + if label not in self.memory: + raise ValueError(f"Block with label {label} does not exist") if not isinstance(value, str): raise ValueError(f"Provided value must be a string") - self.memory[name].value = value + self.memory[label].value = value # TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names. @@ -192,41 +190,41 @@ class BasicBlockMemory(Memory): # assert block.name is not None and block.name != "", "each existing chat block must have a name" # self.link_block(name=block.name, block=block) assert block.label is not None and block.label != "", "each existing chat block must have a name" - self.link_block(name=block.label, block=block) + self.link_block(block=block) - def core_memory_append(self: "Agent", name: str, content: str) -> Optional[str]: # type: ignore + def core_memory_append(self: "Agent", label: str, content: str) -> Optional[str]: # type: ignore """ Append to the contents of core memory. Args: - name (str): Section of the memory to be edited (persona or human). + label (str): Section of the memory to be edited (persona or human). content (str): Content to write to the memory. All unicode (including emojis) are supported. Returns: Optional[str]: None is always returned as this function does not produce a response. """ - current_value = str(self.memory.get_block(name).value) + current_value = str(self.memory.get_block(label).value) new_value = current_value + "\n" + str(content) - self.memory.update_block_value(name=name, value=new_value) + self.memory.update_block_value(label=label, value=new_value) return None - def core_memory_replace(self: "Agent", name: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore + def core_memory_replace(self: "Agent", label: str, old_content: str, new_content: str) -> Optional[str]: # type: ignore """ Replace the contents of core memory. To delete memories, use an empty string for new_content. Args: - name (str): Section of the memory to be edited (persona or human). + label (str): Section of the memory to be edited (persona or human). old_content (str): String to replace. Must be an exact match. new_content (str): Content to write to the memory. All unicode (including emojis) are supported. Returns: Optional[str]: None is always returned as this function does not produce a response. """ - current_value = str(self.memory.get_block(name).value) + current_value = str(self.memory.get_block(label).value) if old_content not in current_value: - raise ValueError(f"Old content '{old_content}' not found in memory block '{name}'") + raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'") new_value = current_value.replace(str(old_content), str(new_content)) - self.memory.update_block_value(name=name, value=new_value) + self.memory.update_block_value(label=label, value=new_value) return None @@ -245,8 +243,8 @@ class ChatMemory(BasicBlockMemory): limit (int): The character limit for each block. """ super().__init__() - self.link_block(name="persona", block=Block(name="persona", value=persona, limit=limit, label="persona")) - self.link_block(name="human", block=Block(name="human", value=human, limit=limit, label="human")) + self.link_block(block=Block(value=persona, limit=limit, label="persona")) + self.link_block(block=Block(value=human, limit=limit, label="human")) class UpdateMemory(BaseModel): diff --git a/letta/server/server.py b/letta/server/server.py index 98e0ef83..67a3b498 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1098,7 +1098,7 @@ class SyncServer(Server): block.value = request.value if request.value is not None else block.value block.name = request.name if request.name is not None else block.name self.ms.update_block(block=block) - return block + return self.ms.get_block(block_id=request.id) def delete_block(self, block_id: str): block = self.get_block(block_id) @@ -1413,7 +1413,7 @@ class SyncServer(Server): if value is None: continue if letta_agent.memory.get_block(key) != value: - letta_agent.memory.update_block_value(name=key, value=value) # update agent memory + letta_agent.memory.update_block_value(label=key, value=value) # update agent memory modified = True # If we modified the memory contents, we need to rebuild the memory block inside the system message diff --git a/tests/test_tools.py b/tests/test_tools.py index b8507d65..e2ddf999 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -134,8 +134,8 @@ def test_create_agent_tool(client): str: The agent that was deleted. """ - self.memory.update_block_value(name="human", value="") - self.memory.update_block_value(name="persona", value="") + self.memory.update_block_value(label="human", value="") + self.memory.update_block_value(label="persona", value="") print("UPDATED MEMORY", self.memory.memory) return None