@@ -1600,6 +1600,18 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to remove agent memory block: {response.text}")
|
||||
return Memory(**response.json())
|
||||
|
||||
def update_agent_memory_limit(self, agent_id: str, block_label: str, limit: int) -> Memory:
|
||||
|
||||
# @router.patch("/{agent_id}/memory/limit", response_model=Memory, operation_id="update_agent_memory_limit")
|
||||
response = requests.patch(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/memory/limit",
|
||||
headers=self.headers,
|
||||
json={"label": block_label, "limit": limit},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update agent memory limit: {response.text}")
|
||||
return Memory(**response.json())
|
||||
|
||||
|
||||
class LocalClient(AbstractClient):
|
||||
"""
|
||||
@@ -2823,3 +2835,6 @@ class LocalClient(AbstractClient):
|
||||
|
||||
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
|
||||
return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label)
|
||||
|
||||
def update_agent_memory_limit(self, agent_id: str, block_label: str, limit: int) -> Memory:
|
||||
return self.server.update_agent_memory_limit(user_id=self.user_id, agent_id=agent_id, block_label=block_label, limit=limit)
|
||||
|
||||
@@ -124,6 +124,13 @@ class BlockUpdate(BaseBlock):
|
||||
extra = "ignore" # Ignores extra fields
|
||||
|
||||
|
||||
class BlockLimitUpdate(BaseModel):
|
||||
"""Update the limit of a block"""
|
||||
|
||||
label: str = Field(..., description="Label of the block.")
|
||||
limit: int = Field(..., description="New limit of the block.")
|
||||
|
||||
|
||||
class UpdatePersona(BlockUpdate):
|
||||
"""Update a persona block"""
|
||||
|
||||
|
||||
@@ -187,6 +187,19 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
# Then swap the block to the new label
|
||||
self.memory[new_label] = self.memory.pop(current_label)
|
||||
|
||||
def update_block_limit(self, label: str, limit: int):
|
||||
"""Update the limit of a block"""
|
||||
if label not in self.memory:
|
||||
raise ValueError(f"Block with label {label} does not exist")
|
||||
if not isinstance(limit, int):
|
||||
raise ValueError(f"Provided limit must be an integer")
|
||||
|
||||
# Check to make sure the new limit is greater than the current length of the block
|
||||
if len(self.memory[label].value) > limit:
|
||||
raise ValueError(f"New limit {limit} is less than the current length of the block {len(self.memory[label].value)}")
|
||||
|
||||
self.memory[label].limit = limit
|
||||
|
||||
|
||||
# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
|
||||
class BasicBlockMemory(Memory):
|
||||
|
||||
@@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
||||
from letta.schemas.block import Block, BlockCreate, BlockLabelUpdate
|
||||
from letta.schemas.block import Block, BlockCreate, BlockLabelUpdate, BlockLimitUpdate
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import (
|
||||
LegacyLettaMessage,
|
||||
@@ -218,6 +218,7 @@ def update_agent_memory(
|
||||
):
|
||||
"""
|
||||
Update the core memory of a specific agent.
|
||||
This endpoint accepts new memory contents (labels as keys, and values as values) and updates the core memory of the agent identified by the user ID and agent ID.
|
||||
This endpoint accepts new memory contents to update the core memory of the agent.
|
||||
This endpoint only supports modifying existing blocks; it does not support deleting/unlinking or creating/linking blocks.
|
||||
"""
|
||||
@@ -287,6 +288,27 @@ def remove_agent_memory_block(
|
||||
return updated_memory
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/memory/limit", response_model=Memory, operation_id="update_agent_memory_limit")
|
||||
def update_agent_memory_limit(
|
||||
agent_id: str,
|
||||
update_label: BlockLimitUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the limit of a block in an agent's memory.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
memory = server.update_agent_memory_limit(
|
||||
user_id=actor.id,
|
||||
agent_id=agent_id,
|
||||
block_label=update_label.label,
|
||||
limit=update_label.limit,
|
||||
)
|
||||
return memory
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary")
|
||||
def get_agent_recall_memory_summary(
|
||||
agent_id: str,
|
||||
|
||||
@@ -1929,3 +1929,33 @@ class SyncServer(Server):
|
||||
raise ValueError(f"Agent with id {agent_id} not found after linking block")
|
||||
assert unlinked_block.label not in updated_agent.memory.list_block_labels()
|
||||
return updated_agent.memory
|
||||
|
||||
def update_agent_memory_limit(self, user_id: str, agent_id: str, block_label: str, limit: int) -> Memory:
|
||||
"""Update the limit of a block in an agent's memory"""
|
||||
|
||||
# Get the user
|
||||
user = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
|
||||
# Link a block to an agent's memory
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
letta_agent.memory.update_block_limit(label=block_label, limit=limit)
|
||||
assert block_label in letta_agent.memory.list_block_labels()
|
||||
|
||||
# write out the update the database
|
||||
self.block_manager.create_or_update_block(block=letta_agent.memory.get_block(block_label), actor=user)
|
||||
|
||||
# check that the block was updated
|
||||
updated_block = self.block_manager.get_block_by_id(block_id=letta_agent.memory.get_block(block_label).id, actor=user)
|
||||
assert updated_block and updated_block.limit == limit
|
||||
|
||||
# Recompile the agent memory
|
||||
letta_agent.rebuild_memory(force=True, ms=self.ms)
|
||||
|
||||
# save agent
|
||||
save_agent(letta_agent, self.ms)
|
||||
|
||||
updated_agent = self.ms.get_agent(agent_id=agent_id)
|
||||
if updated_agent is None:
|
||||
raise ValueError(f"Agent with id {agent_id} not found after linking block")
|
||||
assert updated_agent.memory.get_block(label=block_label).limit == limit
|
||||
return updated_agent.memory
|
||||
|
||||
@@ -810,3 +810,34 @@ def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], a
|
||||
|
||||
# finally:
|
||||
# client.delete_agent(new_agent.id)
|
||||
|
||||
|
||||
def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can update the limit of a block in an agent's memory"""
|
||||
|
||||
agent = client.create_agent(name=create_random_username())
|
||||
|
||||
try:
|
||||
current_labels = agent.memory.list_block_labels()
|
||||
example_label = current_labels[0]
|
||||
example_new_limit = 1
|
||||
current_block = agent.memory.get_block(label=example_label)
|
||||
current_block_length = len(current_block.value)
|
||||
|
||||
assert example_new_limit != agent.memory.get_block(label=example_label).limit
|
||||
assert example_new_limit < current_block_length
|
||||
|
||||
# We expect this to throw a value error
|
||||
with pytest.raises(ValueError):
|
||||
client.update_agent_memory_limit(agent_id=agent.id, block_label=example_label, limit=example_new_limit)
|
||||
|
||||
# Now try the same thing with a higher limit
|
||||
example_new_limit = current_block_length + 10000
|
||||
assert example_new_limit > current_block_length
|
||||
client.update_agent_memory_limit(agent_id=agent.id, block_label=example_label, limit=example_new_limit)
|
||||
|
||||
updated_agent = client.get_agent(agent_id=agent.id)
|
||||
assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit
|
||||
|
||||
finally:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
@@ -140,3 +140,16 @@ def test_update_block_label(sample_memory: Memory):
|
||||
sample_memory.update_block_label(current_label=test_old_label, new_label=test_new_label)
|
||||
assert test_new_label in sample_memory.list_block_labels()
|
||||
assert test_old_label not in sample_memory.list_block_labels()
|
||||
|
||||
|
||||
def test_update_block_limit(sample_memory: Memory):
|
||||
"""Test updating the limit of a block"""
|
||||
|
||||
test_new_limit = 1000
|
||||
current_labels = sample_memory.list_block_labels()
|
||||
test_old_label = current_labels[0]
|
||||
|
||||
assert sample_memory.get_block(label=test_old_label).limit != test_new_limit
|
||||
|
||||
sample_memory.update_block_limit(label=test_old_label, limit=test_new_limit)
|
||||
assert sample_memory.get_block(label=test_old_label).limit == test_new_limit
|
||||
|
||||
Reference in New Issue
Block a user