feat: routes for adding/linking new memory blocks to agents + unlinking blocks from agents (#2083)
This commit is contained in:
@@ -1565,6 +1565,41 @@ class RESTClient(AbstractClient):
|
||||
# Parse and return the deleted organization
|
||||
return Organization(**response.json())
|
||||
|
||||
def update_agent_memory_label(self, agent_id: str, current_label: str, new_label: str) -> Memory:
|
||||
|
||||
# @router.patch("/{agent_id}/memory/label", response_model=Memory, operation_id="update_agent_memory_label")
|
||||
response = requests.patch(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/label",
|
||||
headers=self.headers,
|
||||
json={"current_label": current_label, "new_label": new_label},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update agent memory label: {response.text}")
|
||||
return Memory(**response.json())
|
||||
|
||||
def add_agent_memory_block(self, agent_id: str, create_block: BlockCreate) -> Memory:
|
||||
|
||||
# @router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block")
|
||||
response = requests.post(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block",
|
||||
headers=self.headers,
|
||||
json=create_block.model_dump(),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to add agent memory block: {response.text}")
|
||||
return Memory(**response.json())
|
||||
|
||||
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
|
||||
|
||||
# @router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block")
|
||||
response = requests.delete(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block/{block_label}",
|
||||
headers=self.headers,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to remove agent memory block: {response.text}")
|
||||
return Memory(**response.json())
|
||||
|
||||
|
||||
class LocalClient(AbstractClient):
|
||||
"""
|
||||
@@ -2773,3 +2808,18 @@ class LocalClient(AbstractClient):
|
||||
|
||||
def delete_org(self, org_id: str) -> Organization:
|
||||
return self.server.organization_manager.delete_organization_by_id(org_id=org_id)
|
||||
|
||||
def update_agent_memory_label(self, agent_id: str, current_label: str, new_label: str) -> Memory:
|
||||
return self.server.update_agent_memory_label(
|
||||
user_id=self.user_id, agent_id=agent_id, current_block_label=current_label, new_block_label=new_label
|
||||
)
|
||||
|
||||
def add_agent_memory_block(self, agent_id: str, create_block: BlockCreate) -> Memory:
|
||||
block_req = Block(**create_block.model_dump())
|
||||
block = self.server.block_manager.create_or_update_block(actor=self.user, block=block_req)
|
||||
# Link the block to the agent
|
||||
updated_memory = self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block.id)
|
||||
return updated_memory
|
||||
|
||||
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
|
||||
return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
@@ -95,6 +95,13 @@ class BlockCreate(BaseBlock):
|
||||
label: str = Field(..., description="Label of the block.")
|
||||
|
||||
|
||||
class BlockLabelUpdate(BaseModel):
|
||||
"""Update the label of a block"""
|
||||
|
||||
current_label: str = Field(..., description="Current label of the block.")
|
||||
new_label: str = Field(..., description="New label of the block.")
|
||||
|
||||
|
||||
class CreatePersona(BlockCreate):
|
||||
"""Create a persona block"""
|
||||
|
||||
|
||||
@@ -158,6 +158,13 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
|
||||
self.memory[block.label] = block
|
||||
|
||||
def unlink_block(self, block_label: str) -> Block:
|
||||
"""Unlink a block from the memory object"""
|
||||
if block_label not in self.memory:
|
||||
raise ValueError(f"Block with label {block_label} does not exist")
|
||||
|
||||
return self.memory.pop(block_label)
|
||||
|
||||
def update_block_value(self, label: str, value: str):
|
||||
"""Update the value of a block"""
|
||||
if label not in self.memory:
|
||||
@@ -167,6 +174,19 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
|
||||
self.memory[label].value = value
|
||||
|
||||
def update_block_label(self, current_label: str, new_label: str):
|
||||
"""Update the label of a block"""
|
||||
if current_label not in self.memory:
|
||||
raise ValueError(f"Block with label {current_label} does not exist")
|
||||
if not isinstance(new_label, str):
|
||||
raise ValueError(f"Provided new label must be a string")
|
||||
|
||||
# First change the label of the block
|
||||
self.memory[current_label].label = new_label
|
||||
|
||||
# Then swap the block to the new label
|
||||
self.memory[new_label] = self.memory.pop(current_label)
|
||||
|
||||
|
||||
# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
|
||||
class BasicBlockMemory(Memory):
|
||||
|
||||
@@ -7,6 +7,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
||||
from letta.schemas.block import Block, BlockCreate, BlockLabelUpdate
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import (
|
||||
LegacyLettaMessage,
|
||||
@@ -217,7 +218,8 @@ def update_agent_memory(
|
||||
):
|
||||
"""
|
||||
Update the core memory of a specific agent.
|
||||
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
|
||||
This endpoint accepts new memory contents to update the core memory of the agent.
|
||||
This endpoint only supports modifying existing blocks; it does not support deleting/unlinking or creating/linking blocks.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
@@ -225,6 +227,66 @@ def update_agent_memory(
|
||||
return memory
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/memory/label", response_model=Memory, operation_id="update_agent_memory_label")
|
||||
def update_agent_memory_label(
|
||||
agent_id: str,
|
||||
update_label: BlockLabelUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the label of a block in an agent's memory.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
memory = server.update_agent_memory_label(
|
||||
user_id=actor.id, agent_id=agent_id, current_block_label=update_label.current_label, new_block_label=update_label.new_label
|
||||
)
|
||||
return memory
|
||||
|
||||
|
||||
@router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block")
|
||||
def add_agent_memory_block(
|
||||
agent_id: str,
|
||||
create_block: BlockCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Creates a memory block and links it to the agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Copied from POST /blocks
|
||||
block_req = Block(**create_block.model_dump())
|
||||
block = server.block_manager.create_or_update_block(actor=actor, block=block_req)
|
||||
|
||||
# Link the block to the agent
|
||||
updated_memory = server.link_block_to_agent_memory(user_id=actor.id, agent_id=agent_id, block_id=block.id)
|
||||
|
||||
return updated_memory
|
||||
|
||||
|
||||
@router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block")
|
||||
def remove_agent_memory_block(
|
||||
agent_id: str,
|
||||
# TODO should this be block_id, or the label?
|
||||
# I think label is OK since it's user-friendly + guaranteed to be unique within a Memory object
|
||||
block_label: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Unlink the block from the agent
|
||||
updated_memory = server.unlink_block_from_agent_memory(user_id=actor.id, agent_id=agent_id, block_label=block_label)
|
||||
|
||||
return updated_memory
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary")
|
||||
def get_agent_recall_memory_summary(
|
||||
agent_id: str,
|
||||
|
||||
@@ -409,8 +409,10 @@ class SyncServer(Server):
|
||||
logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}")
|
||||
raise
|
||||
|
||||
def _get_or_load_agent(self, agent_id: str) -> Agent:
|
||||
def _get_or_load_agent(self, agent_id: str, caching: bool = True) -> Agent:
|
||||
"""Check if the agent is in-memory, then load"""
|
||||
|
||||
# Gets the agent state
|
||||
agent_state = self.ms.get_agent(agent_id=agent_id)
|
||||
if not agent_state:
|
||||
raise ValueError(f"Agent does not exist")
|
||||
@@ -418,11 +420,24 @@ class SyncServer(Server):
|
||||
actor = self.user_manager.get_user_by_id(user_id)
|
||||
|
||||
logger.debug(f"Checking for agent user_id={user_id} agent_id={agent_id}")
|
||||
# TODO: consider disabling loading cached agents due to potential concurrency issues
|
||||
letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
|
||||
if not letta_agent:
|
||||
logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
|
||||
if caching:
|
||||
# TODO: consider disabling loading cached agents due to potential concurrency issues
|
||||
letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
|
||||
if not letta_agent:
|
||||
logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
|
||||
letta_agent = self._load_agent(agent_id=agent_id, actor=actor)
|
||||
else:
|
||||
# This breaks unit tests in test_local_client.py
|
||||
letta_agent = self._load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
|
||||
# if not letta_agent:
|
||||
# logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
|
||||
|
||||
# NOTE: no longer caching, always forcing a lot from the database
|
||||
# Loads the agent objects
|
||||
# letta_agent = self._load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
return letta_agent
|
||||
|
||||
def _step(
|
||||
@@ -1441,6 +1456,7 @@ class SyncServer(Server):
|
||||
# If we modified the memory contents, we need to rebuild the memory block inside the system message
|
||||
if modified:
|
||||
letta_agent.rebuild_memory()
|
||||
# letta_agent.rebuild_memory(force=True, ms=self.ms) # This breaks unit tests in test_local_client.py
|
||||
# save agent
|
||||
save_agent(letta_agent, self.ms)
|
||||
|
||||
@@ -1827,3 +1843,89 @@ class SyncServer(Server):
|
||||
# Get the current message
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
return letta_agent.get_context_window()
|
||||
|
||||
def update_agent_memory_label(self, user_id: str, agent_id: str, current_block_label: str, new_block_label: str) -> Memory:
|
||||
"""Update the label of a block in an agent's memory"""
|
||||
|
||||
# Get the user
|
||||
user = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
|
||||
# Link a block to an agent's memory
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
letta_agent.memory.update_block_label(current_label=current_block_label, new_label=new_block_label)
|
||||
assert new_block_label in letta_agent.memory.list_block_labels()
|
||||
self.block_manager.create_or_update_block(block=letta_agent.memory.get_block(new_block_label), actor=user)
|
||||
|
||||
# check that the block was updated
|
||||
updated_block = self.block_manager.get_block_by_id(block_id=letta_agent.memory.get_block(new_block_label).id, actor=user)
|
||||
|
||||
# Recompile the agent memory
|
||||
letta_agent.rebuild_memory(force=True, ms=self.ms)
|
||||
|
||||
# save agent
|
||||
save_agent(letta_agent, self.ms)
|
||||
|
||||
updated_agent = self.ms.get_agent(agent_id=agent_id)
|
||||
if updated_agent is None:
|
||||
raise ValueError(f"Agent with id {agent_id} not found after linking block")
|
||||
assert new_block_label in updated_agent.memory.list_block_labels()
|
||||
assert current_block_label not in updated_agent.memory.list_block_labels()
|
||||
return updated_agent.memory
|
||||
|
||||
def link_block_to_agent_memory(self, user_id: str, agent_id: str, block_id: str) -> Memory:
|
||||
"""Link a block to an agent's memory"""
|
||||
|
||||
# Get the user
|
||||
user = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
|
||||
# Get the block first
|
||||
block = self.block_manager.get_block_by_id(block_id=block_id, actor=user)
|
||||
if block is None:
|
||||
raise ValueError(f"Block with id {block_id} not found")
|
||||
|
||||
# Link a block to an agent's memory
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
letta_agent.memory.link_block(block=block)
|
||||
assert block.label in letta_agent.memory.list_block_labels()
|
||||
|
||||
# Recompile the agent memory
|
||||
letta_agent.rebuild_memory(force=True, ms=self.ms)
|
||||
|
||||
# save agent
|
||||
save_agent(letta_agent, self.ms)
|
||||
|
||||
updated_agent = self.ms.get_agent(agent_id=agent_id)
|
||||
if updated_agent is None:
|
||||
raise ValueError(f"Agent with id {agent_id} not found after linking block")
|
||||
assert block.label in updated_agent.memory.list_block_labels()
|
||||
|
||||
return updated_agent.memory
|
||||
|
||||
def unlink_block_from_agent_memory(self, user_id: str, agent_id: str, block_label: str, delete_if_no_ref: bool = True) -> Memory:
|
||||
"""Unlink a block from an agent's memory. If the block is not linked to any agent, delete it."""
|
||||
|
||||
# Get the user
|
||||
user = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
|
||||
# Link a block to an agent's memory
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
unlinked_block = letta_agent.memory.unlink_block(block_label=block_label)
|
||||
assert unlinked_block.label not in letta_agent.memory.list_block_labels()
|
||||
|
||||
# Check if the block is linked to any other agent
|
||||
# TODO needs reference counting GC to handle loose blocks
|
||||
# block = self.block_manager.get_block_by_id(block_id=unlinked_block.id, actor=user)
|
||||
# if block is None:
|
||||
# raise ValueError(f"Block with id {block_id} not found")
|
||||
|
||||
# Recompile the agent memory
|
||||
letta_agent.rebuild_memory(force=True, ms=self.ms)
|
||||
|
||||
# save agent
|
||||
save_agent(letta_agent, self.ms)
|
||||
|
||||
updated_agent = self.ms.get_agent(agent_id=agent_id)
|
||||
if updated_agent is None:
|
||||
raise ValueError(f"Agent with id {agent_id} not found after linking block")
|
||||
assert unlinked_block.label not in updated_agent.memory.list_block_labels()
|
||||
return updated_agent.memory
|
||||
|
||||
@@ -15,6 +15,7 @@ from letta.client.client import LocalClient, RESTClient
|
||||
from letta.constants import DEFAULT_PRESET
|
||||
from letta.orm import FileMetadata, Source
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import BlockCreate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole, MessageStreamStatus
|
||||
from letta.schemas.letta_message import (
|
||||
@@ -32,7 +33,7 @@ from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import model_settings
|
||||
from letta.utils import get_utc_time
|
||||
from letta.utils import create_random_username, get_utc_time
|
||||
from tests.helpers.client_helper import upload_file_using_client
|
||||
|
||||
# from tests.utils import create_config
|
||||
@@ -730,3 +731,82 @@ def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], a
|
||||
# Verify all tags are removed
|
||||
final_tags = client.get_agent(agent_id=agent.id).tags
|
||||
assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}"
|
||||
|
||||
|
||||
def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can update the label of a block in an agent's memory"""
|
||||
|
||||
agent = client.create_agent(name=create_random_username())
|
||||
|
||||
try:
|
||||
current_labels = agent.memory.list_block_labels()
|
||||
example_label = current_labels[0]
|
||||
example_new_label = "example_new_label"
|
||||
assert example_new_label not in current_labels
|
||||
|
||||
client.update_agent_memory_label(agent_id=agent.id, current_label=example_label, new_label=example_new_label)
|
||||
|
||||
updated_agent = client.get_agent(agent_id=agent.id)
|
||||
assert example_new_label in updated_agent.memory.list_block_labels()
|
||||
|
||||
finally:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
|
||||
def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can add and remove a block from an agent's memory"""
|
||||
|
||||
agent = client.create_agent(name=create_random_username())
|
||||
|
||||
try:
|
||||
current_labels = agent.memory.list_block_labels()
|
||||
example_new_label = "example_new_label"
|
||||
example_new_value = "example value"
|
||||
assert example_new_label not in current_labels
|
||||
|
||||
# Link a new memory block
|
||||
client.add_agent_memory_block(
|
||||
agent_id=agent.id,
|
||||
create_block=BlockCreate(
|
||||
label=example_new_label,
|
||||
value=example_new_value,
|
||||
limit=1000,
|
||||
),
|
||||
)
|
||||
|
||||
updated_agent = client.get_agent(agent_id=agent.id)
|
||||
assert example_new_label in updated_agent.memory.list_block_labels()
|
||||
|
||||
# Now unlink the block
|
||||
client.remove_agent_memory_block(agent_id=agent.id, block_label=example_new_label)
|
||||
|
||||
updated_agent = client.get_agent(agent_id=agent.id)
|
||||
assert example_new_label not in updated_agent.memory.list_block_labels()
|
||||
|
||||
finally:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
|
||||
# def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# """Test that the token limit is enforced for the core memory blocks"""
|
||||
|
||||
# # Create an agent
|
||||
# new_agent = client.create_agent(
|
||||
# name="test-core-memory-token-limits",
|
||||
# tools=BASE_TOOLS,
|
||||
# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=2000),
|
||||
# )
|
||||
|
||||
# try:
|
||||
# # Then intentionally set the limit to be extremely low
|
||||
# client.update_agent(
|
||||
# agent_id=new_agent.id,
|
||||
# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=100),
|
||||
# )
|
||||
|
||||
# # TODO we should probably not allow updating the core memory limit if
|
||||
|
||||
# # TODO in which case we should modify this test to actually to a proper token counter check
|
||||
|
||||
# finally:
|
||||
# client.delete_agent(new_agent.id)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
# Import the classes here, assuming the above definitions are in a module named memory_module
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.memory import ChatMemory, Memory
|
||||
|
||||
|
||||
@@ -105,3 +106,37 @@ def test_memory_jinja2_set_template(sample_memory: Memory):
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
sample_memory.set_prompt_template(prompt_template=template_bad_memory_structure)
|
||||
|
||||
|
||||
def test_link_unlink_block(sample_memory: Memory):
|
||||
"""Test linking and unlinking a block to the memory"""
|
||||
|
||||
# Link a new block
|
||||
|
||||
test_new_label = "test_new_label"
|
||||
test_new_value = "test_new_value"
|
||||
test_new_block = Block(label=test_new_label, value=test_new_value, limit=2000)
|
||||
|
||||
current_labels = sample_memory.list_block_labels()
|
||||
assert test_new_label not in current_labels
|
||||
|
||||
sample_memory.link_block(block=test_new_block)
|
||||
assert test_new_label in sample_memory.list_block_labels()
|
||||
assert sample_memory.get_block(test_new_label).value == test_new_value
|
||||
|
||||
# Unlink the block
|
||||
sample_memory.unlink_block(block_label=test_new_label)
|
||||
assert test_new_label not in sample_memory.list_block_labels()
|
||||
|
||||
|
||||
def test_update_block_label(sample_memory: Memory):
|
||||
"""Test updating the label of a block"""
|
||||
|
||||
test_new_label = "test_new_label"
|
||||
current_labels = sample_memory.list_block_labels()
|
||||
assert test_new_label not in current_labels
|
||||
test_old_label = current_labels[0]
|
||||
|
||||
sample_memory.update_block_label(current_label=test_old_label, new_label=test_new_label)
|
||||
assert test_new_label in sample_memory.list_block_labels()
|
||||
assert test_old_label not in sample_memory.list_block_labels()
|
||||
|
||||
Reference in New Issue
Block a user