feat: routes for adding/linking new memory blocks to agents + unlinking blocks from agents (#2083)

This commit is contained in:
Charles Packer
2024-11-21 20:08:47 -08:00
committed by GitHub
parent 06744c9193
commit 507a60f71c
7 changed files with 364 additions and 8 deletions

View File

@@ -1565,6 +1565,41 @@ class RESTClient(AbstractClient):
# Parse and return the deleted organization
return Organization(**response.json())
def update_agent_memory_label(self, agent_id: str, current_label: str, new_label: str) -> Memory:
# @router.patch("/{agent_id}/memory/label", response_model=Memory, operation_id="update_agent_memory_label")
response = requests.patch(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/label",
headers=self.headers,
json={"current_label": current_label, "new_label": new_label},
)
if response.status_code != 200:
raise ValueError(f"Failed to update agent memory label: {response.text}")
return Memory(**response.json())
def add_agent_memory_block(self, agent_id: str, create_block: BlockCreate) -> Memory:
# @router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block")
response = requests.post(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block",
headers=self.headers,
json=create_block.model_dump(),
)
if response.status_code != 200:
raise ValueError(f"Failed to add agent memory block: {response.text}")
return Memory(**response.json())
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
# @router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block")
response = requests.delete(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/block/{block_label}",
headers=self.headers,
)
if response.status_code != 200:
raise ValueError(f"Failed to remove agent memory block: {response.text}")
return Memory(**response.json())
class LocalClient(AbstractClient):
"""
@@ -2773,3 +2808,18 @@ class LocalClient(AbstractClient):
def delete_org(self, org_id: str) -> Organization:
return self.server.organization_manager.delete_organization_by_id(org_id=org_id)
def update_agent_memory_label(self, agent_id: str, current_label: str, new_label: str) -> Memory:
return self.server.update_agent_memory_label(
user_id=self.user_id, agent_id=agent_id, current_block_label=current_label, new_block_label=new_label
)
def add_agent_memory_block(self, agent_id: str, create_block: BlockCreate) -> Memory:
block_req = Block(**create_block.model_dump())
block = self.server.block_manager.create_or_update_block(actor=self.user, block=block_req)
# Link the block to the agent
updated_memory = self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block.id)
return updated_memory
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label)

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field, model_validator
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from letta.schemas.letta_base import LettaBase
@@ -95,6 +95,13 @@ class BlockCreate(BaseBlock):
label: str = Field(..., description="Label of the block.")
class BlockLabelUpdate(BaseModel):
"""Update the label of a block"""
current_label: str = Field(..., description="Current label of the block.")
new_label: str = Field(..., description="New label of the block.")
class CreatePersona(BlockCreate):
"""Create a persona block"""

View File

@@ -158,6 +158,13 @@ class Memory(BaseModel, validate_assignment=True):
self.memory[block.label] = block
def unlink_block(self, block_label: str) -> Block:
"""Unlink a block from the memory object"""
if block_label not in self.memory:
raise ValueError(f"Block with label {block_label} does not exist")
return self.memory.pop(block_label)
def update_block_value(self, label: str, value: str):
"""Update the value of a block"""
if label not in self.memory:
@@ -167,6 +174,19 @@ class Memory(BaseModel, validate_assignment=True):
self.memory[label].value = value
def update_block_label(self, current_label: str, new_label: str):
"""Update the label of a block"""
if current_label not in self.memory:
raise ValueError(f"Block with label {current_label} does not exist")
if not isinstance(new_label, str):
raise ValueError(f"Provided new label must be a string")
# First change the label of the block
self.memory[current_label].label = new_label
# Then swap the block to the new label
self.memory[new_label] = self.memory.pop(current_label)
# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
class BasicBlockMemory(Memory):

View File

@@ -7,6 +7,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from letta.schemas.block import Block, BlockCreate, BlockLabelUpdate
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
LegacyLettaMessage,
@@ -217,7 +218,8 @@ def update_agent_memory(
):
"""
Update the core memory of a specific agent.
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
This endpoint accepts new memory contents to update the core memory of the agent.
This endpoint only supports modifying existing blocks; it does not support deleting/unlinking or creating/linking blocks.
"""
actor = server.get_user_or_default(user_id=user_id)
@@ -225,6 +227,66 @@ def update_agent_memory(
return memory
@router.patch("/{agent_id}/memory/label", response_model=Memory, operation_id="update_agent_memory_label")
def update_agent_memory_label(
agent_id: str,
update_label: BlockLabelUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the label of a block in an agent's memory.
"""
actor = server.get_user_or_default(user_id=user_id)
memory = server.update_agent_memory_label(
user_id=actor.id, agent_id=agent_id, current_block_label=update_label.current_label, new_block_label=update_label.new_label
)
return memory
@router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block")
def add_agent_memory_block(
agent_id: str,
create_block: BlockCreate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Creates a memory block and links it to the agent.
"""
actor = server.get_user_or_default(user_id=user_id)
# Copied from POST /blocks
block_req = Block(**create_block.model_dump())
block = server.block_manager.create_or_update_block(actor=actor, block=block_req)
# Link the block to the agent
updated_memory = server.link_block_to_agent_memory(user_id=actor.id, agent_id=agent_id, block_id=block.id)
return updated_memory
@router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block")
def remove_agent_memory_block(
agent_id: str,
# TODO should this be block_id, or the label?
# I think label is OK since it's user-friendly + guaranteed to be unique within a Memory object
block_label: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted.
"""
actor = server.get_user_or_default(user_id=user_id)
# Unlink the block from the agent
updated_memory = server.unlink_block_from_agent_memory(user_id=actor.id, agent_id=agent_id, block_label=block_label)
return updated_memory
@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary")
def get_agent_recall_memory_summary(
agent_id: str,

View File

@@ -409,8 +409,10 @@ class SyncServer(Server):
logger.exception(f"Error occurred while trying to get agent {agent_id}:\n{e}")
raise
def _get_or_load_agent(self, agent_id: str) -> Agent:
def _get_or_load_agent(self, agent_id: str, caching: bool = True) -> Agent:
"""Check if the agent is in-memory, then load"""
# Gets the agent state
agent_state = self.ms.get_agent(agent_id=agent_id)
if not agent_state:
raise ValueError(f"Agent does not exist")
@@ -418,11 +420,24 @@ class SyncServer(Server):
actor = self.user_manager.get_user_by_id(user_id)
logger.debug(f"Checking for agent user_id={user_id} agent_id={agent_id}")
# TODO: consider disabling loading cached agents due to potential concurrency issues
letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
if not letta_agent:
logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
if caching:
# TODO: consider disabling loading cached agents due to potential concurrency issues
letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
if not letta_agent:
logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
letta_agent = self._load_agent(agent_id=agent_id, actor=actor)
else:
# This breaks unit tests in test_local_client.py
letta_agent = self._load_agent(agent_id=agent_id, actor=actor)
# letta_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
# if not letta_agent:
# logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
# NOTE: no longer caching, always forcing a lot from the database
# Loads the agent objects
# letta_agent = self._load_agent(agent_id=agent_id, actor=actor)
return letta_agent
def _step(
@@ -1441,6 +1456,7 @@ class SyncServer(Server):
# If we modified the memory contents, we need to rebuild the memory block inside the system message
if modified:
letta_agent.rebuild_memory()
# letta_agent.rebuild_memory(force=True, ms=self.ms) # This breaks unit tests in test_local_client.py
# save agent
save_agent(letta_agent, self.ms)
@@ -1827,3 +1843,89 @@ class SyncServer(Server):
# Get the current message
letta_agent = self._get_or_load_agent(agent_id=agent_id)
return letta_agent.get_context_window()
def update_agent_memory_label(self, user_id: str, agent_id: str, current_block_label: str, new_block_label: str) -> Memory:
"""Update the label of a block in an agent's memory"""
# Get the user
user = self.user_manager.get_user_by_id(user_id=user_id)
# Link a block to an agent's memory
letta_agent = self._get_or_load_agent(agent_id=agent_id)
letta_agent.memory.update_block_label(current_label=current_block_label, new_label=new_block_label)
assert new_block_label in letta_agent.memory.list_block_labels()
self.block_manager.create_or_update_block(block=letta_agent.memory.get_block(new_block_label), actor=user)
# check that the block was updated
updated_block = self.block_manager.get_block_by_id(block_id=letta_agent.memory.get_block(new_block_label).id, actor=user)
# Recompile the agent memory
letta_agent.rebuild_memory(force=True, ms=self.ms)
# save agent
save_agent(letta_agent, self.ms)
updated_agent = self.ms.get_agent(agent_id=agent_id)
if updated_agent is None:
raise ValueError(f"Agent with id {agent_id} not found after linking block")
assert new_block_label in updated_agent.memory.list_block_labels()
assert current_block_label not in updated_agent.memory.list_block_labels()
return updated_agent.memory
def link_block_to_agent_memory(self, user_id: str, agent_id: str, block_id: str) -> Memory:
"""Link a block to an agent's memory"""
# Get the user
user = self.user_manager.get_user_by_id(user_id=user_id)
# Get the block first
block = self.block_manager.get_block_by_id(block_id=block_id, actor=user)
if block is None:
raise ValueError(f"Block with id {block_id} not found")
# Link a block to an agent's memory
letta_agent = self._get_or_load_agent(agent_id=agent_id)
letta_agent.memory.link_block(block=block)
assert block.label in letta_agent.memory.list_block_labels()
# Recompile the agent memory
letta_agent.rebuild_memory(force=True, ms=self.ms)
# save agent
save_agent(letta_agent, self.ms)
updated_agent = self.ms.get_agent(agent_id=agent_id)
if updated_agent is None:
raise ValueError(f"Agent with id {agent_id} not found after linking block")
assert block.label in updated_agent.memory.list_block_labels()
return updated_agent.memory
def unlink_block_from_agent_memory(self, user_id: str, agent_id: str, block_label: str, delete_if_no_ref: bool = True) -> Memory:
"""Unlink a block from an agent's memory. If the block is not linked to any agent, delete it."""
# Get the user
user = self.user_manager.get_user_by_id(user_id=user_id)
# Link a block to an agent's memory
letta_agent = self._get_or_load_agent(agent_id=agent_id)
unlinked_block = letta_agent.memory.unlink_block(block_label=block_label)
assert unlinked_block.label not in letta_agent.memory.list_block_labels()
# Check if the block is linked to any other agent
# TODO needs reference counting GC to handle loose blocks
# block = self.block_manager.get_block_by_id(block_id=unlinked_block.id, actor=user)
# if block is None:
# raise ValueError(f"Block with id {block_id} not found")
# Recompile the agent memory
letta_agent.rebuild_memory(force=True, ms=self.ms)
# save agent
save_agent(letta_agent, self.ms)
updated_agent = self.ms.get_agent(agent_id=agent_id)
if updated_agent is None:
raise ValueError(f"Agent with id {agent_id} not found after linking block")
assert unlinked_block.label not in updated_agent.memory.list_block_labels()
return updated_agent.memory

View File

@@ -15,6 +15,7 @@ from letta.client.client import LocalClient, RESTClient
from letta.constants import DEFAULT_PRESET
from letta.orm import FileMetadata, Source
from letta.schemas.agent import AgentState
from letta.schemas.block import BlockCreate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.letta_message import (
@@ -32,7 +33,7 @@ from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from letta.services.tool_manager import ToolManager
from letta.settings import model_settings
from letta.utils import get_utc_time
from letta.utils import create_random_username, get_utc_time
from tests.helpers.client_helper import upload_file_using_client
# from tests.utils import create_config
@@ -730,3 +731,82 @@ def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], a
# Verify all tags are removed
final_tags = client.get_agent(agent_id=agent.id).tags
assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}"
def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test that we can update the label of a block in an agent's memory"""
agent = client.create_agent(name=create_random_username())
try:
current_labels = agent.memory.list_block_labels()
example_label = current_labels[0]
example_new_label = "example_new_label"
assert example_new_label not in current_labels
client.update_agent_memory_label(agent_id=agent.id, current_label=example_label, new_label=example_new_label)
updated_agent = client.get_agent(agent_id=agent.id)
assert example_new_label in updated_agent.memory.list_block_labels()
finally:
client.delete_agent(agent.id)
def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test that we can add and remove a block from an agent's memory"""
agent = client.create_agent(name=create_random_username())
try:
current_labels = agent.memory.list_block_labels()
example_new_label = "example_new_label"
example_new_value = "example value"
assert example_new_label not in current_labels
# Link a new memory block
client.add_agent_memory_block(
agent_id=agent.id,
create_block=BlockCreate(
label=example_new_label,
value=example_new_value,
limit=1000,
),
)
updated_agent = client.get_agent(agent_id=agent.id)
assert example_new_label in updated_agent.memory.list_block_labels()
# Now unlink the block
client.remove_agent_memory_block(agent_id=agent.id, block_label=example_new_label)
updated_agent = client.get_agent(agent_id=agent.id)
assert example_new_label not in updated_agent.memory.list_block_labels()
finally:
client.delete_agent(agent.id)
# def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState):
# """Test that the token limit is enforced for the core memory blocks"""
# # Create an agent
# new_agent = client.create_agent(
# name="test-core-memory-token-limits",
# tools=BASE_TOOLS,
# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=2000),
# )
# try:
# # Then intentionally set the limit to be extremely low
# client.update_agent(
# agent_id=new_agent.id,
# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=100),
# )
# # TODO we should probably not allow updating the core memory limit if
# # TODO in which case we should modify this test to actually to a proper token counter check
# finally:
# client.delete_agent(new_agent.id)

View File

@@ -1,6 +1,7 @@
import pytest
# Import the classes here, assuming the above definitions are in a module named memory_module
from letta.schemas.block import Block
from letta.schemas.memory import ChatMemory, Memory
@@ -105,3 +106,37 @@ def test_memory_jinja2_set_template(sample_memory: Memory):
)
with pytest.raises(ValueError):
sample_memory.set_prompt_template(prompt_template=template_bad_memory_structure)
def test_link_unlink_block(sample_memory: Memory):
"""Test linking and unlinking a block to the memory"""
# Link a new block
test_new_label = "test_new_label"
test_new_value = "test_new_value"
test_new_block = Block(label=test_new_label, value=test_new_value, limit=2000)
current_labels = sample_memory.list_block_labels()
assert test_new_label not in current_labels
sample_memory.link_block(block=test_new_block)
assert test_new_label in sample_memory.list_block_labels()
assert sample_memory.get_block(test_new_label).value == test_new_value
# Unlink the block
sample_memory.unlink_block(block_label=test_new_label)
assert test_new_label not in sample_memory.list_block_labels()
def test_update_block_label(sample_memory: Memory):
"""Test updating the label of a block"""
test_new_label = "test_new_label"
current_labels = sample_memory.list_block_labels()
assert test_new_label not in current_labels
test_old_label = current_labels[0]
sample_memory.update_block_label(current_label=test_old_label, new_label=test_new_label)
assert test_new_label in sample_memory.list_block_labels()
assert test_old_label not in sample_memory.list_block_labels()