fix: move all attach detach to be under agents (#723)

Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
mlong93
2025-01-22 12:16:31 -08:00
committed by GitHub
parent 50de3cb4b7
commit c9e18e66f1
8 changed files with 418 additions and 350 deletions

View File

@@ -92,19 +92,19 @@ class AbstractClient(object):
):
raise NotImplementedError
def get_tools_from_agent(self, agent_id: str):
def get_tools_from_agent(self, agent_id: str) -> List[Tool]:
raise NotImplementedError
def add_tool_to_agent(self, agent_id: str, tool_id: str):
def attach_tool(self, agent_id: str, tool_id: str) -> AgentState:
raise NotImplementedError
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
def detach_tool(self, agent_id: str, tool_id: str) -> AgentState:
raise NotImplementedError
def rename_agent(self, agent_id: str, new_name: str):
def rename_agent(self, agent_id: str, new_name: str) -> AgentState:
raise NotImplementedError
def delete_agent(self, agent_id: str):
def delete_agent(self, agent_id: str) -> None:
raise NotImplementedError
def get_agent(self, agent_id: str) -> AgentState:
@@ -218,6 +218,18 @@ class AbstractClient(object):
def get_tool_id(self, name: str) -> Optional[str]:
raise NotImplementedError
def list_attached_tools(self, agent_id: str) -> List[Tool]:
"""
List all tools attached to an agent.
Args:
agent_id (str): ID of the agent
Returns:
List[Tool]: A list of attached tools
"""
raise NotImplementedError
def upsert_base_tools(self) -> List[Tool]:
raise NotImplementedError
@@ -242,10 +254,10 @@ class AbstractClient(object):
def get_source_id(self, source_name: str) -> str:
raise NotImplementedError
def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
def attach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState:
raise NotImplementedError
def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
def detach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState:
raise NotImplementedError
def list_sources(self) -> List[Source]:
@@ -397,6 +409,26 @@ class AbstractClient(object):
"""
raise NotImplementedError
def attach_block(self, agent_id: str, block_id: str) -> AgentState:
"""
Attach a block to an agent.
Args:
agent_id (str): ID of the agent
block_id (str): ID of the block to attach
"""
raise NotImplementedError
def detach_block(self, agent_id: str, block_id: str) -> AgentState:
"""
Detach a block from an agent.
Args:
agent_id (str): ID of the agent
block_id (str): ID of the block to detach
"""
raise NotImplementedError
class RESTClient(AbstractClient):
"""
@@ -620,7 +652,7 @@ class RESTClient(AbstractClient):
embedding_config: Optional[EmbeddingConfig] = None,
message_ids: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
):
) -> AgentState:
"""
Update an existing agent
@@ -670,7 +702,7 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to get tools from agents: {response.text}")
return [Tool(**tool) for tool in response.json()]
def add_tool_to_agent(self, agent_id: str, tool_id: str):
def attach_tool(self, agent_id: str, tool_id: str) -> AgentState:
"""
Add tool to an existing agent
@@ -681,12 +713,12 @@ class RESTClient(AbstractClient):
Returns:
agent_state (AgentState): State of the updated agent
"""
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/add-tool/{tool_id}", headers=self.headers)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/tools/attach/{tool_id}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update agent: {response.text}")
return AgentState(**response.json())
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
def detach_tool(self, agent_id: str, tool_id: str) -> AgentState:
"""
Removes tools from an existing agent
@@ -698,12 +730,12 @@ class RESTClient(AbstractClient):
agent_state (AgentState): State of the updated agent
"""
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/remove-tool/{tool_id}", headers=self.headers)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/tools/detach/{tool_id}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update agent: {response.text}")
return AgentState(**response.json())
def rename_agent(self, agent_id: str, new_name: str):
def rename_agent(self, agent_id: str, new_name: str) -> AgentState:
"""
Rename an agent
@@ -711,10 +743,12 @@ class RESTClient(AbstractClient):
agent_id (str): ID of the agent
new_name (str): New name for the agent
Returns:
agent_state (AgentState): State of the updated agent
"""
return self.update_agent(agent_id, name=new_name)
def delete_agent(self, agent_id: str):
def delete_agent(self, agent_id: str) -> None:
"""
Delete an agent
@@ -1425,7 +1459,7 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to update source: {response.text}")
return Source(**response.json())
def attach_source_to_agent(self, source_id: str, agent_id: str):
def attach_source(self, source_id: str, agent_id: str) -> AgentState:
"""
Attach a source to an agent
@@ -1435,15 +1469,20 @@ class RESTClient(AbstractClient):
source_name (str): Name of the source
"""
params = {"agent_id": agent_id}
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/attach", params=params, headers=self.headers)
response = requests.patch(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/sources/attach/{source_id}", params=params, headers=self.headers
)
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
return AgentState(**response.json())
def detach_source(self, source_id: str, agent_id: str):
def detach_source(self, source_id: str, agent_id: str) -> AgentState:
"""Detach a source from an agent"""
params = {"agent_id": str(agent_id)}
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/detach", params=params, headers=self.headers)
response = requests.patch(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/sources/detach/{source_id}", params=params, headers=self.headers
)
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
return Source(**response.json())
return AgentState(**response.json())
# tools
@@ -1466,6 +1505,21 @@ class RESTClient(AbstractClient):
return None
return tools[0].id
def list_attached_tools(self, agent_id: str) -> List[Tool]:
"""
List all tools attached to an agent.
Args:
agent_id (str): ID of the agent
Returns:
List[Tool]: A list of attached tools
"""
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/tools", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list attached tools: {response.text}")
return [Tool(**tool) for tool in response.json()]
def upsert_base_tools(self) -> List[Tool]:
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools/add-base-tools/", headers=self.headers)
if response.status_code != 200:
@@ -1835,66 +1889,36 @@ class RESTClient(AbstractClient):
block = self.get_agent_memory_block(agent_id, current_label)
return self.update_block(block.id, label=new_label)
# TODO: remove this
def add_agent_memory_block(self, agent_id: str, create_block: CreateBlock) -> Memory:
def attach_block(self, agent_id: str, block_id: str) -> AgentState:
"""
Create and link a memory block to an agent's core memory
Attach a block to an agent.
Args:
agent_id (str): The agent ID
create_block (CreateBlock): The block to create
Returns:
memory (Memory): The updated memory
agent_id (str): ID of the agent
block_id (str): ID of the block to attach
"""
response = requests.post(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/core_memory/blocks",
headers=self.headers,
json=create_block.model_dump(),
)
if response.status_code != 200:
raise ValueError(f"Failed to add agent memory block: {response.text}")
return Memory(**response.json())
def link_agent_memory_block(self, agent_id: str, block_id: str) -> Memory:
"""
Link a block to an agent's core memory
Args:
agent_id (str): The agent ID
block_id (str): The block ID
Returns:
memory (Memory): The updated memory
"""
params = {"agent_id": agent_id}
response = requests.patch(
f"{self.base_url}/{self.api_prefix}/blocks/{block_id}/attach",
params=params,
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/core_memory/blocks/attach/{block_id}",
headers=self.headers,
)
if response.status_code != 200:
raise ValueError(f"Failed to link agent memory block: {response.text}")
return Block(**response.json())
raise ValueError(f"Failed to attach block to agent: {response.text}")
return AgentState(**response.json())
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
def detach_block(self, agent_id: str, block_id: str) -> AgentState:
"""
Unlike a block from the agent's core memory
Detach a block from an agent.
Args:
agent_id (str): The agent ID
block_label (str): The block label
Returns:
memory (Memory): The updated memory
agent_id (str): ID of the agent
block_id (str): ID of the block to detach
"""
response = requests.delete(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/core_memory/blocks/{block_label}",
headers=self.headers,
response = requests.patch(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/core_memory/blocks/detach/{block_id}", headers=self.headers
)
if response.status_code != 200:
raise ValueError(f"Failed to remove agent memory block: {response.text}")
return Memory(**response.json())
raise ValueError(f"Failed to detach block from agent: {response.text}")
return AgentState(**response.json())
def list_agent_memory_blocks(self, agent_id: str) -> List[Block]:
"""
@@ -2381,7 +2405,7 @@ class LocalClient(AbstractClient):
Returns:
agent_state (AgentState): State of the updated agent
"""
# TODO: add the abilitty to reset linked block_ids
# TODO: add the ability to reset linked block_ids
self.interface.clear()
agent_state = self.server.agent_manager.update_agent(
agent_id,
@@ -2413,7 +2437,7 @@ class LocalClient(AbstractClient):
self.interface.clear()
return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user).tools
def add_tool_to_agent(self, agent_id: str, tool_id: str):
def attach_tool(self, agent_id: str, tool_id: str) -> AgentState:
"""
Add tool to an existing agent
@@ -2428,7 +2452,7 @@ class LocalClient(AbstractClient):
agent_state = self.server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user)
return agent_state
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
def detach_tool(self, agent_id: str, tool_id: str) -> AgentState:
"""
Removes tools from an existing agent
@@ -2443,17 +2467,20 @@ class LocalClient(AbstractClient):
agent_state = self.server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user)
return agent_state
def rename_agent(self, agent_id: str, new_name: str):
def rename_agent(self, agent_id: str, new_name: str) -> AgentState:
"""
Rename an agent
Args:
agent_id (str): ID of the agent
new_name (str): New name for the agent
"""
self.update_agent(agent_id, name=new_name)
def delete_agent(self, agent_id: str):
Returns:
agent_state (AgentState): State of the updated agent
"""
return self.update_agent(agent_id, name=new_name)
def delete_agent(self, agent_id: str) -> None:
"""
Delete an agent
@@ -3028,6 +3055,18 @@ class LocalClient(AbstractClient):
tool = self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user)
return tool.id if tool else None
def list_attached_tools(self, agent_id: str) -> List[Tool]:
"""
List all tools attached to an agent.
Args:
agent_id (str): ID of the agent
Returns:
List[Tool]: List of tools attached to the agent
"""
return self.server.agent_manager.list_attached_tools(agent_id=agent_id, actor=self.user)
def load_data(self, connector: DataConnector, source_name: str):
"""
Load data into a source
@@ -3061,14 +3100,14 @@ class LocalClient(AbstractClient):
self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id, actor=self.user)
return job
def delete_file_from_source(self, source_id: str, file_id: str):
def delete_file_from_source(self, source_id: str, file_id: str) -> None:
self.server.source_manager.delete_file(file_id, actor=self.user)
def get_job(self, job_id: str):
return self.server.job_manager.get_job_by_id(job_id=job_id, actor=self.user)
def delete_job(self, job_id: str):
return self.server.job_manager.delete_job(job_id=job_id, actor=self.user)
return self.server.job_manager.delete_job_by_id(job_id=job_id, actor=self.user)
def list_jobs(self):
return self.server.job_manager.list_jobs(actor=self.user)
@@ -3127,7 +3166,7 @@ class LocalClient(AbstractClient):
"""
return self.server.source_manager.get_source_by_name(source_name=source_name, actor=self.user).id
def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
def attach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState:
"""
Attach a source to an agent
@@ -3140,9 +3179,9 @@ class LocalClient(AbstractClient):
source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)
source_id = source.id
self.server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=self.user)
return self.server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=self.user)
def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
def detach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState:
"""
Detach a source from an agent by removing all `Passage` objects that were loaded from the source from archival memory.
Args:
@@ -3475,51 +3514,7 @@ class LocalClient(AbstractClient):
block = self.get_agent_memory_block(agent_id, current_label)
return self.update_block(block.id, label=new_label)
# TODO: remove this
def add_agent_memory_block(self, agent_id: str, create_block: CreateBlock) -> Memory:
"""
Create and link a memory block to an agent's core memory
Args:
agent_id (str): The agent ID
create_block (CreateBlock): The block to create
Returns:
memory (Memory): The updated memory
"""
block_req = Block(**create_block.model_dump())
block = self.server.block_manager.create_or_update_block(actor=self.user, block=block_req)
# Link the block to the agent
agent = self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=self.user)
return agent.memory
def link_agent_memory_block(self, agent_id: str, block_id: str) -> Memory:
"""
Link a block to an agent's core memory
Args:
agent_id (str): The agent ID
block_id (str): The block ID
Returns:
memory (Memory): The updated memory
"""
return self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=self.user)
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
"""
Unlike a block from the agent's core memory
Args:
agent_id (str): The agent ID
block_label (str): The block label
Returns:
memory (Memory): The updated memory
"""
return self.server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=self.user)
def list_agent_memory_blocks(self, agent_id: str) -> List[Block]:
def get_agent_memory_blocks(self, agent_id: str) -> List[Block]:
"""
Get all the blocks in the agent's core memory
@@ -3600,6 +3595,26 @@ class LocalClient(AbstractClient):
data["label"] = label
return self.server.block_manager.update_block(block_id, actor=self.user, block_update=BlockUpdate(**data))
def attach_block(self, agent_id: str, block_id: str) -> AgentState:
"""
Attach a block to an agent.
Args:
agent_id (str): ID of the agent
block_id (str): ID of the block to attach
"""
return self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=self.user)
def detach_block(self, agent_id: str, block_id: str) -> AgentState:
"""
Detach a block from an agent.
Args:
agent_id (str): ID of the agent
block_id (str): ID of the block to detach
"""
return self.server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=self.user)
def get_run_messages(
self,
run_id: str,

View File

@@ -126,43 +126,63 @@ def get_tools_from_agent(
):
"""Get tools from an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).tools
return server.agent_manager.list_attached_tools(agent_id=agent_id, actor=actor)
@router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent")
def add_tool_to_agent(
@router.patch("/{agent_id}/tools/attach/{tool_id}", response_model=AgentState, operation_id="attach_tool_to_agent")
def attach_tool(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"),
):
"""Add tools to an existing agent"""
"""
Attach a tool to an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
@router.patch("/{agent_id}/remove-tool/{tool_id}", response_model=AgentState, operation_id="remove_tool_from_agent")
def remove_tool_from_agent(
@router.patch("/{agent_id}/tools/detach/{tool_id}", response_model=AgentState, operation_id="detach_tool_from_agent")
def detach_tool(
agent_id: str,
tool_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"),
):
"""Add tools to an existing agent"""
"""
Detach a tool from an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
@router.patch("/{agent_id}/reset-messages", response_model=AgentState, operation_id="reset_messages")
def reset_messages(
@router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent")
def attach_source(
agent_id: str,
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"),
):
"""Resets the messages for an agent"""
"""
Attach a source to an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages)
return server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor)
@router.patch("/{agent_id}/sources/detach/{source_id}", response_model=AgentState, operation_id="detach_source_from_agent")
def detach_source(
agent_id: str,
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Detach a source from an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
@router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent")
@@ -263,49 +283,6 @@ def list_agent_memory_blocks(
raise HTTPException(status_code=404, detail=str(e))
@router.post("/{agent_id}/core_memory/blocks", response_model=Memory, operation_id="add_agent_memory_block")
def add_agent_memory_block(
agent_id: str,
create_block: CreateBlock = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Creates a memory block and links it to the agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
# Copied from POST /blocks
# TODO: Should have block_manager accept only CreateBlock
# TODO: This will be possible once we move ID creation to the ORM
block_req = Block(**create_block.model_dump())
block = server.block_manager.create_or_update_block(actor=actor, block=block_req)
# Link the block to the agent
agent = server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=actor)
return agent.memory
@router.delete("/{agent_id}/core_memory/blocks/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block_by_label")
def remove_agent_memory_block(
agent_id: str,
# TODO should this be block_id, or the label?
# I think label is OK since it's user-friendly + guaranteed to be unique within a Memory object
block_label: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
# Unlink the block from the agent
agent = server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
return agent.memory
@router.patch("/{agent_id}/core_memory/blocks/{block_label}", response_model=Block, operation_id="update_agent_memory_block_by_label")
def update_agent_memory_block(
agent_id: str,
@@ -315,7 +292,7 @@ def update_agent_memory_block(
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted.
Updates a memory block of an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
@@ -328,6 +305,34 @@ def update_agent_memory_block(
return block
@router.patch("/{agent_id}/core_memory/blocks/attach/{block_id}", response_model=AgentState, operation_id="attach_block_to_agent")
def attach_block(
agent_id: str,
block_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Attach a block to an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor)
@router.patch("/{agent_id}/core_memory/blocks/detach/{block_id}", response_model=AgentState, operation_id="detach_block_from_agent")
def detach_block(
agent_id: str,
block_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Detach a block from an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor)
@router.get("/{agent_id}/archival_memory", response_model=List[Passage], operation_id="list_agent_archival_memory")
def get_agent_archival_memory(
agent_id: str,
@@ -610,3 +615,15 @@ async def send_message_async(
)
return run
@router.patch("/{agent_id}/reset-messages", response_model=AgentState, operation_id="reset_messages")
def reset_messages(
agent_id: str,
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Resets the messages for an agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages)

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Response
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from letta.orm.errors import NoResultFound
from letta.schemas.block import Block, BlockUpdate, CreateBlock
@@ -73,41 +73,3 @@ def get_block(
return block
except NoResultFound:
raise HTTPException(status_code=404, detail="Block not found")
@router.patch("/{block_id}/attach", response_model=None, status_code=204, operation_id="link_agent_memory_block")
def link_agent_memory_block(
block_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Link a memory block to an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
try:
server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor)
return Response(status_code=204)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@router.patch("/{block_id}/detach", response_model=None, status_code=204, operation_id="unlink_agent_memory_block")
def unlink_agent_memory_block(
block_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Unlink a memory block from an agent
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
try:
server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor)
return Response(status_code=204)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -111,36 +111,6 @@ def delete_source(
server.delete_source(source_id=source_id, actor=actor)
@router.post("/{source_id}/attach", response_model=Source, operation_id="attach_agent_to_source")
def attach_source_to_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Attach a data source to an existing agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=actor)
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
@router.post("/{source_id}/detach", response_model=Source, operation_id="detach_agent_from_source")
def detach_source_from_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
) -> None:
"""
Detach a data source from an existing agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
@router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source")
def upload_file_to_source(
file: UploadFile,

View File

@@ -25,6 +25,7 @@ from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
from letta.schemas.user import User as PydanticUser
from letta.services.block_manager import BlockManager
@@ -537,7 +538,7 @@ class AgentManager:
# Source Management
# ======================================================================================================================
@enforce_types
def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None:
def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Attaches a source to an agent.
@@ -567,6 +568,7 @@ class AgentManager:
# Commit the changes
agent.update(session, actor=actor)
return agent.to_pydantic()
@enforce_types
def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
@@ -588,7 +590,7 @@ class AgentManager:
return [source.to_pydantic() for source in agent.sources]
@enforce_types
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None:
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Detaches a source from an agent.
@@ -602,10 +604,17 @@ class AgentManager:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Remove the source from the relationship
agent.sources = [s for s in agent.sources if s.id != source_id]
remaining_sources = [s for s in agent.sources if s.id != source_id]
if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship
logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")
# Update the sources relationship
agent.sources = remaining_sources
# Commit the changes
agent.update(session, actor=actor)
return agent.to_pydantic()
# ======================================================================================================================
# Block management
@@ -1011,6 +1020,22 @@ class AgentManager:
agent.update(session, actor=actor)
return agent.to_pydantic()
@enforce_types
def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:
"""
List all tools attached to an agent.
Args:
agent_id: ID of the agent to list tools for.
actor: User performing the action.
Returns:
List[PydanticTool]: List of tools attached to the agent.
"""
with self.session_maker() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
return [tool.to_pydantic() for tool in agent.tools]
# ======================================================================================================================
# Tag Management
# ======================================================================================================================

View File

@@ -13,7 +13,6 @@ from sqlalchemy import delete
from letta import LocalClient, RESTClient, create_client
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
from letta.schemas.agent import AgentState
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.job import JobStatus
@@ -113,46 +112,6 @@ def clear_tables():
session.commit()
def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]):
# _reset_config()
# create a block
block = client.create_block(label="human", value="username: sarah")
# create agents with shared block
from letta.schemas.block import Block
from letta.schemas.memory import BasicBlockMemory
# persona1_block = client.create_block(label="persona", value="you are agent 1")
# persona2_block = client.create_block(label="persona", value="you are agent 2")
# create agents
agent_state1 = client.create_agent(
name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id]
)
agent_state2 = client.create_agent(
name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id]
)
## attach shared block to both agents
# client.link_agent_memory_block(agent_state1.id, block.id)
# client.link_agent_memory_block(agent_state2.id, block.id)
# update memory
client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
# check agent 2 memory
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
client.user_message(agent_id=agent_state2.id, message="whats my name?")
assert (
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
# cleanup
client.delete_agent(agent_state1.id)
client.delete_agent(agent_state2.id)
def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
"""
Test sandbox config and environment variable functions for both LocalClient and RESTClient.
@@ -204,6 +163,11 @@ def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]
client.delete_sandbox_config(sandbox_config_id=sandbox_config.id)
# --------------------------------------------------------------------------------------------------------------------
# Agent tags
# --------------------------------------------------------------------------------------------------------------------
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient]):
"""
Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
@@ -306,6 +270,49 @@ def test_agent_tags(client: Union[LocalClient, RESTClient]):
client.delete_agent(agent3.id)
# --------------------------------------------------------------------------------------------------------------------
# Agent memory blocks
# --------------------------------------------------------------------------------------------------------------------
def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]):
# _reset_config()
# create a block
block = client.create_block(label="human", value="username: sarah")
# create agents with shared block
from letta.schemas.block import Block
from letta.schemas.memory import BasicBlockMemory
# persona1_block = client.create_block(label="persona", value="you are agent 1")
# persona2_block = client.create_block(label="persona", value="you are agent 2")
# create agents
agent_state1 = client.create_agent(
name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id]
)
agent_state2 = client.create_agent(
name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id]
)
## attach shared block to both agents
# client.link_agent_memory_block(agent_state1.id, block.id)
# client.link_agent_memory_block(agent_state2.id, block.id)
# update memory
client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
# check agent 2 memory
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
client.user_message(agent_id=agent_state2.id, message="whats my name?")
assert (
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
# cleanup
client.delete_agent(agent_state1.id)
client.delete_agent(agent_state2.id)
def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test that we can update the label of a block in an agent's memory"""
@@ -326,38 +333,32 @@ def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent
client.delete_agent(agent.id)
def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_attach_detach_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test that we can add and remove a block from an agent's memory"""
agent = client.create_agent(name=create_random_username())
current_labels = agent.memory.list_block_labels()
example_new_label = current_labels[0] + "_v2"
example_new_value = "example value"
assert example_new_label not in current_labels
try:
current_labels = agent.memory.list_block_labels()
example_new_label = "example_new_label"
example_new_value = "example value"
assert example_new_label not in current_labels
# Link a new memory block
block = client.create_block(
label=example_new_label,
value=example_new_value,
limit=1000,
)
updated_agent = client.attach_block(
agent_id=agent.id,
block_id=block.id,
)
assert example_new_label in updated_agent.memory.list_block_labels()
# Link a new memory block
client.add_agent_memory_block(
agent_id=agent.id,
create_block=CreateBlock(
label=example_new_label,
value=example_new_value,
limit=1000,
),
)
updated_agent = client.get_agent(agent_id=agent.id)
assert example_new_label in updated_agent.memory.list_block_labels()
# Now unlink the block
client.remove_agent_memory_block(agent_id=agent.id, block_label=example_new_label)
updated_agent = client.get_agent(agent_id=agent.id)
assert example_new_label not in updated_agent.memory.list_block_labels()
finally:
client.delete_agent(agent.id)
# Now unlink the block
updated_agent = client.detach_block(
agent_id=agent.id,
block_id=block.id,
)
assert example_new_label not in updated_agent.memory.list_block_labels()
# def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState):
@@ -413,24 +414,9 @@ def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient]):
client.delete_agent(agent.id)
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
assert send_message_response, "Sending message failed"
messages_response = client.get_messages(agent_id=agent.id, limit=1)
assert len(messages_response) > 0, "Retrieving messages failed"
def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)"""
send_system_message_response = client.send_message(
agent_id=agent.id, message="Event occurred: The user just logged off.", role="system"
)
assert send_system_message_response, "Sending message failed"
# --------------------------------------------------------------------------------------------------------------------
# Agent Tools
# --------------------------------------------------------------------------------------------------------------------
def test_function_return_limit(client: Union[LocalClient, RESTClient]):
"""Test to see if the function return limit works"""
@@ -503,6 +489,70 @@ def test_function_always_error(client: Union[LocalClient, RESTClient]):
client.delete_agent(agent_id=agent.id)
def test_attach_detach_agent_tool(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test that we can attach and detach a tool from an agent"""
try:
# Create a tool
def example_tool(x: int) -> int:
"""
This is an example tool.
Parameters:
x (int): The input value.
Returns:
int: The output value.
"""
return x * 2
tool = client.create_or_update_tool(func=example_tool, name="test_tool")
# Initially tool should not be attached
initial_tools = client.list_attached_tools(agent_id=agent.id)
assert tool.id not in [t.id for t in initial_tools]
# Attach tool
new_agent_state = client.attach_tool(agent_id=agent.id, tool_id=tool.id)
assert tool.id in [t.id for t in new_agent_state.tools]
# Verify tool is attached
updated_tools = client.list_attached_tools(agent_id=agent.id)
assert tool.id in [t.id for t in updated_tools]
# Detach tool
new_agent_state = client.detach_tool(agent_id=agent.id, tool_id=tool.id)
assert tool.id not in [t.id for t in new_agent_state.tools]
# Verify tool is detached
final_tools = client.list_attached_tools(agent_id=agent.id)
assert tool.id not in [t.id for t in final_tools]
finally:
client.delete_tool(tool.id)
# --------------------------------------------------------------------------------------------------------------------
# AgentMessages
# --------------------------------------------------------------------------------------------------------------------
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
assert send_message_response, "Sending message failed"
messages_response = client.get_messages(agent_id=agent.id, limit=1)
assert len(messages_response) > 0, "Retrieving messages failed"
def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)"""
send_system_message_response = client.send_message(
agent_id=agent.id, message="Event occurred: The user just logged off.", role="system"
)
assert send_system_message_response, "Sending message failed"
@pytest.mark.asyncio
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
"""
@@ -580,9 +630,9 @@ def test_send_message_async(client: Union[LocalClient, RESTClient], agent: Agent
assert usage.total_tokens == usage.completion_tokens + usage.prompt_tokens
# ==========================================
# TESTS FOR AGENT LISTING
# ==========================================
# ----------------------------------------------------------------------------------------------------
# Agent listing
# ----------------------------------------------------------------------------------------------------
def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_agent_one, search_agent_two):
@@ -678,3 +728,33 @@ def test_agent_creation(client: Union[LocalClient, RESTClient]):
assert all(tool.id in tool_ids for tool in agent_tools)
client.delete_agent(agent_id=agent.id)
# --------------------------------------------------------------------------------------------------------------------
# Agent sources
# --------------------------------------------------------------------------------------------------------------------
def test_attach_detach_agent_source(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test that we can attach and detach a source from an agent"""
# Create a source
source = client.create_source(
name="test_source",
)
initial_sources = client.list_attached_sources(agent_id=agent.id)
assert source.id not in [s.id for s in initial_sources]
# Attach source
client.attach_source(agent_id=agent.id, source_id=source.id)
# Verify source is attached
final_sources = client.list_attached_sources(agent_id=agent.id)
assert source.id in [s.id for s in final_sources]
# Detach source
client.detach_source(agent_id=agent.id, source_id=source.id)
# Verify source is detached
final_sources = client.list_attached_sources(agent_id=agent.id)
assert source.id not in [s.id for s in final_sources]
client.delete_source(source.id)

View File

@@ -205,9 +205,9 @@ def test_archival_memory(mock_e2b_api_key_none, client: Union[LocalClient, RESTC
passages = client.get_archival_memory(agent.id)
assert passage.text in [p.text for p in passages], f"Missing passage {passage.text} in {passages}"
# get archival memory summary
archival_summary = client.get_archival_memory_summary(agent.id)
assert archival_summary.size == 1, f"Archival memory summary size is {archival_summary.size}"
# # get archival memory summary
# archival_summary = client.get_agent_archival_memory_summary(agent.id)
# assert archival_summary.size == 1, f"Archival memory summary size is {archival_summary.size}"
# delete archival memory
client.delete_archival_memory(agent.id, passage.id)
@@ -500,7 +500,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
assert len(archival_memories) == 0
# attach a source
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)
client.attach_source(source_id=source.id, agent_id=agent.id)
# list attached sources
attached_sources = client.list_attached_sources(agent_id=agent.id)
@@ -521,8 +521,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
# detach the source
assert len(client.get_archival_memory(agent_id=agent.id)) > 0, "No archival memory"
deleted_source = client.detach_source(source_id=source.id, agent_id=agent.id)
assert deleted_source.id == source.id
client.detach_source(source_id=source.id, agent_id=agent.id)
archival_memories = client.get_archival_memory(agent_id=agent.id)
assert len(archival_memories) == 0, f"Failed to detach source: {len(archival_memories)}"
assert source.id not in [s.id for s in client.list_attached_sources(agent.id)]

View File

@@ -141,7 +141,7 @@ def test_agent_add_remove_tools(client: LocalClient, agent):
curr_num_tools = len(agent_state.tools)
# add both tools to agent in steps
agent_state = client.add_tool_to_agent(agent_id=agent_state.id, tool_id=github_tool.id)
agent_state = client.attach_tool(agent_id=agent_state.id, tool_id=github_tool.id)
# confirm that both tools are in the agent state
# we could access it like agent_state.tools, but will use the client function instead
@@ -153,7 +153,7 @@ def test_agent_add_remove_tools(client: LocalClient, agent):
assert github_tool.name in curr_tool_names
# remove only the github tool
agent_state = client.remove_tool_from_agent(agent_id=agent_state.id, tool_id=github_tool.id)
agent_state = client.detach_tool(agent_id=agent_state.id, tool_id=github_tool.id)
# confirm that only one tool left
curr_tools = client.get_tools_from_agent(agent_state.id)