fix: move all attach detach to be under agents (#723)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -92,19 +92,19 @@ class AbstractClient(object):
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_tools_from_agent(self, agent_id: str):
|
||||
def get_tools_from_agent(self, agent_id: str) -> List[Tool]:
|
||||
raise NotImplementedError
|
||||
|
||||
def add_tool_to_agent(self, agent_id: str, tool_id: str):
|
||||
def attach_tool(self, agent_id: str, tool_id: str) -> AgentState:
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
|
||||
def detach_tool(self, agent_id: str, tool_id: str) -> AgentState:
|
||||
raise NotImplementedError
|
||||
|
||||
def rename_agent(self, agent_id: str, new_name: str):
|
||||
def rename_agent(self, agent_id: str, new_name: str) -> AgentState:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_agent(self, agent_id: str):
|
||||
def delete_agent(self, agent_id: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_agent(self, agent_id: str) -> AgentState:
|
||||
@@ -218,6 +218,18 @@ class AbstractClient(object):
|
||||
def get_tool_id(self, name: str) -> Optional[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_attached_tools(self, agent_id: str) -> List[Tool]:
|
||||
"""
|
||||
List all tools attached to an agent.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
|
||||
Returns:
|
||||
List[Tool]: A list of attached tools
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def upsert_base_tools(self) -> List[Tool]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -242,10 +254,10 @@ class AbstractClient(object):
|
||||
def get_source_id(self, source_name: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
|
||||
def attach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState:
|
||||
raise NotImplementedError
|
||||
|
||||
def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
|
||||
def detach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_sources(self) -> List[Source]:
|
||||
@@ -397,6 +409,26 @@ class AbstractClient(object):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def attach_block(self, agent_id: str, block_id: str) -> AgentState:
|
||||
"""
|
||||
Attach a block to an agent.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
block_id (str): ID of the block to attach
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def detach_block(self, agent_id: str, block_id: str) -> AgentState:
|
||||
"""
|
||||
Detach a block from an agent.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
block_id (str): ID of the block to detach
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RESTClient(AbstractClient):
|
||||
"""
|
||||
@@ -620,7 +652,7 @@ class RESTClient(AbstractClient):
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
message_ids: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
):
|
||||
) -> AgentState:
|
||||
"""
|
||||
Update an existing agent
|
||||
|
||||
@@ -670,7 +702,7 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to get tools from agents: {response.text}")
|
||||
return [Tool(**tool) for tool in response.json()]
|
||||
|
||||
def add_tool_to_agent(self, agent_id: str, tool_id: str):
|
||||
def attach_tool(self, agent_id: str, tool_id: str) -> AgentState:
|
||||
"""
|
||||
Add tool to an existing agent
|
||||
|
||||
@@ -681,12 +713,12 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/add-tool/{tool_id}", headers=self.headers)
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/tools/attach/{tool_id}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update agent: {response.text}")
|
||||
return AgentState(**response.json())
|
||||
|
||||
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
|
||||
def detach_tool(self, agent_id: str, tool_id: str) -> AgentState:
|
||||
"""
|
||||
Removes tools from an existing agent
|
||||
|
||||
@@ -698,12 +730,12 @@ class RESTClient(AbstractClient):
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/remove-tool/{tool_id}", headers=self.headers)
|
||||
response = requests.patch(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/tools/detach/{tool_id}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update agent: {response.text}")
|
||||
return AgentState(**response.json())
|
||||
|
||||
def rename_agent(self, agent_id: str, new_name: str):
|
||||
def rename_agent(self, agent_id: str, new_name: str) -> AgentState:
|
||||
"""
|
||||
Rename an agent
|
||||
|
||||
@@ -711,10 +743,12 @@ class RESTClient(AbstractClient):
|
||||
agent_id (str): ID of the agent
|
||||
new_name (str): New name for the agent
|
||||
|
||||
Returns:
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
return self.update_agent(agent_id, name=new_name)
|
||||
|
||||
def delete_agent(self, agent_id: str):
|
||||
def delete_agent(self, agent_id: str) -> None:
|
||||
"""
|
||||
Delete an agent
|
||||
|
||||
@@ -1425,7 +1459,7 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to update source: {response.text}")
|
||||
return Source(**response.json())
|
||||
|
||||
def attach_source_to_agent(self, source_id: str, agent_id: str):
|
||||
def attach_source(self, source_id: str, agent_id: str) -> AgentState:
|
||||
"""
|
||||
Attach a source to an agent
|
||||
|
||||
@@ -1435,15 +1469,20 @@ class RESTClient(AbstractClient):
|
||||
source_name (str): Name of the source
|
||||
"""
|
||||
params = {"agent_id": agent_id}
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/attach", params=params, headers=self.headers)
|
||||
response = requests.patch(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/sources/attach/{source_id}", params=params, headers=self.headers
|
||||
)
|
||||
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
|
||||
return AgentState(**response.json())
|
||||
|
||||
def detach_source(self, source_id: str, agent_id: str):
|
||||
def detach_source(self, source_id: str, agent_id: str) -> AgentState:
|
||||
"""Detach a source from an agent"""
|
||||
params = {"agent_id": str(agent_id)}
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/sources/{source_id}/detach", params=params, headers=self.headers)
|
||||
response = requests.patch(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/sources/detach/{source_id}", params=params, headers=self.headers
|
||||
)
|
||||
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
|
||||
return Source(**response.json())
|
||||
return AgentState(**response.json())
|
||||
|
||||
# tools
|
||||
|
||||
@@ -1466,6 +1505,21 @@ class RESTClient(AbstractClient):
|
||||
return None
|
||||
return tools[0].id
|
||||
|
||||
def list_attached_tools(self, agent_id: str) -> List[Tool]:
|
||||
"""
|
||||
List all tools attached to an agent.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
|
||||
Returns:
|
||||
List[Tool]: A list of attached tools
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/tools", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to list attached tools: {response.text}")
|
||||
return [Tool(**tool) for tool in response.json()]
|
||||
|
||||
def upsert_base_tools(self) -> List[Tool]:
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools/add-base-tools/", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
@@ -1835,66 +1889,36 @@ class RESTClient(AbstractClient):
|
||||
block = self.get_agent_memory_block(agent_id, current_label)
|
||||
return self.update_block(block.id, label=new_label)
|
||||
|
||||
# TODO: remove this
|
||||
def add_agent_memory_block(self, agent_id: str, create_block: CreateBlock) -> Memory:
|
||||
def attach_block(self, agent_id: str, block_id: str) -> AgentState:
|
||||
"""
|
||||
Create and link a memory block to an agent's core memory
|
||||
Attach a block to an agent.
|
||||
|
||||
Args:
|
||||
agent_id (str): The agent ID
|
||||
create_block (CreateBlock): The block to create
|
||||
|
||||
Returns:
|
||||
memory (Memory): The updated memory
|
||||
agent_id (str): ID of the agent
|
||||
block_id (str): ID of the block to attach
|
||||
"""
|
||||
response = requests.post(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/core_memory/blocks",
|
||||
headers=self.headers,
|
||||
json=create_block.model_dump(),
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to add agent memory block: {response.text}")
|
||||
return Memory(**response.json())
|
||||
|
||||
def link_agent_memory_block(self, agent_id: str, block_id: str) -> Memory:
|
||||
"""
|
||||
Link a block to an agent's core memory
|
||||
|
||||
Args:
|
||||
agent_id (str): The agent ID
|
||||
block_id (str): The block ID
|
||||
|
||||
Returns:
|
||||
memory (Memory): The updated memory
|
||||
"""
|
||||
params = {"agent_id": agent_id}
|
||||
response = requests.patch(
|
||||
f"{self.base_url}/{self.api_prefix}/blocks/{block_id}/attach",
|
||||
params=params,
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/core_memory/blocks/attach/{block_id}",
|
||||
headers=self.headers,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to link agent memory block: {response.text}")
|
||||
return Block(**response.json())
|
||||
raise ValueError(f"Failed to attach block to agent: {response.text}")
|
||||
return AgentState(**response.json())
|
||||
|
||||
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
|
||||
def detach_block(self, agent_id: str, block_id: str) -> AgentState:
|
||||
"""
|
||||
Unlike a block from the agent's core memory
|
||||
Detach a block from an agent.
|
||||
|
||||
Args:
|
||||
agent_id (str): The agent ID
|
||||
block_label (str): The block label
|
||||
|
||||
Returns:
|
||||
memory (Memory): The updated memory
|
||||
agent_id (str): ID of the agent
|
||||
block_id (str): ID of the block to detach
|
||||
"""
|
||||
response = requests.delete(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/core_memory/blocks/{block_label}",
|
||||
headers=self.headers,
|
||||
response = requests.patch(
|
||||
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/core_memory/blocks/detach/{block_id}", headers=self.headers
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to remove agent memory block: {response.text}")
|
||||
return Memory(**response.json())
|
||||
raise ValueError(f"Failed to detach block from agent: {response.text}")
|
||||
return AgentState(**response.json())
|
||||
|
||||
def list_agent_memory_blocks(self, agent_id: str) -> List[Block]:
|
||||
"""
|
||||
@@ -2381,7 +2405,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
# TODO: add the abilitty to reset linked block_ids
|
||||
# TODO: add the ability to reset linked block_ids
|
||||
self.interface.clear()
|
||||
agent_state = self.server.agent_manager.update_agent(
|
||||
agent_id,
|
||||
@@ -2413,7 +2437,7 @@ class LocalClient(AbstractClient):
|
||||
self.interface.clear()
|
||||
return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user).tools
|
||||
|
||||
def add_tool_to_agent(self, agent_id: str, tool_id: str):
|
||||
def attach_tool(self, agent_id: str, tool_id: str) -> AgentState:
|
||||
"""
|
||||
Add tool to an existing agent
|
||||
|
||||
@@ -2428,7 +2452,7 @@ class LocalClient(AbstractClient):
|
||||
agent_state = self.server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user)
|
||||
return agent_state
|
||||
|
||||
def remove_tool_from_agent(self, agent_id: str, tool_id: str):
|
||||
def detach_tool(self, agent_id: str, tool_id: str) -> AgentState:
|
||||
"""
|
||||
Removes tools from an existing agent
|
||||
|
||||
@@ -2443,17 +2467,20 @@ class LocalClient(AbstractClient):
|
||||
agent_state = self.server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user)
|
||||
return agent_state
|
||||
|
||||
def rename_agent(self, agent_id: str, new_name: str):
|
||||
def rename_agent(self, agent_id: str, new_name: str) -> AgentState:
|
||||
"""
|
||||
Rename an agent
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
new_name (str): New name for the agent
|
||||
"""
|
||||
self.update_agent(agent_id, name=new_name)
|
||||
|
||||
def delete_agent(self, agent_id: str):
|
||||
Returns:
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
return self.update_agent(agent_id, name=new_name)
|
||||
|
||||
def delete_agent(self, agent_id: str) -> None:
|
||||
"""
|
||||
Delete an agent
|
||||
|
||||
@@ -3028,6 +3055,18 @@ class LocalClient(AbstractClient):
|
||||
tool = self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user)
|
||||
return tool.id if tool else None
|
||||
|
||||
def list_attached_tools(self, agent_id: str) -> List[Tool]:
|
||||
"""
|
||||
List all tools attached to an agent.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
|
||||
Returns:
|
||||
List[Tool]: List of tools attached to the agent
|
||||
"""
|
||||
return self.server.agent_manager.list_attached_tools(agent_id=agent_id, actor=self.user)
|
||||
|
||||
def load_data(self, connector: DataConnector, source_name: str):
|
||||
"""
|
||||
Load data into a source
|
||||
@@ -3061,14 +3100,14 @@ class LocalClient(AbstractClient):
|
||||
self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id, actor=self.user)
|
||||
return job
|
||||
|
||||
def delete_file_from_source(self, source_id: str, file_id: str):
|
||||
def delete_file_from_source(self, source_id: str, file_id: str) -> None:
|
||||
self.server.source_manager.delete_file(file_id, actor=self.user)
|
||||
|
||||
def get_job(self, job_id: str):
|
||||
return self.server.job_manager.get_job_by_id(job_id=job_id, actor=self.user)
|
||||
|
||||
def delete_job(self, job_id: str):
|
||||
return self.server.job_manager.delete_job(job_id=job_id, actor=self.user)
|
||||
return self.server.job_manager.delete_job_by_id(job_id=job_id, actor=self.user)
|
||||
|
||||
def list_jobs(self):
|
||||
return self.server.job_manager.list_jobs(actor=self.user)
|
||||
@@ -3127,7 +3166,7 @@ class LocalClient(AbstractClient):
|
||||
"""
|
||||
return self.server.source_manager.get_source_by_name(source_name=source_name, actor=self.user).id
|
||||
|
||||
def attach_source_to_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
|
||||
def attach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState:
|
||||
"""
|
||||
Attach a source to an agent
|
||||
|
||||
@@ -3140,9 +3179,9 @@ class LocalClient(AbstractClient):
|
||||
source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user)
|
||||
source_id = source.id
|
||||
|
||||
self.server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=self.user)
|
||||
return self.server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=self.user)
|
||||
|
||||
def detach_source_from_agent(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None):
|
||||
def detach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState:
|
||||
"""
|
||||
Detach a source from an agent by removing all `Passage` objects that were loaded from the source from archival memory.
|
||||
Args:
|
||||
@@ -3475,51 +3514,7 @@ class LocalClient(AbstractClient):
|
||||
block = self.get_agent_memory_block(agent_id, current_label)
|
||||
return self.update_block(block.id, label=new_label)
|
||||
|
||||
# TODO: remove this
|
||||
def add_agent_memory_block(self, agent_id: str, create_block: CreateBlock) -> Memory:
|
||||
"""
|
||||
Create and link a memory block to an agent's core memory
|
||||
|
||||
Args:
|
||||
agent_id (str): The agent ID
|
||||
create_block (CreateBlock): The block to create
|
||||
|
||||
Returns:
|
||||
memory (Memory): The updated memory
|
||||
"""
|
||||
block_req = Block(**create_block.model_dump())
|
||||
block = self.server.block_manager.create_or_update_block(actor=self.user, block=block_req)
|
||||
# Link the block to the agent
|
||||
agent = self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=self.user)
|
||||
return agent.memory
|
||||
|
||||
def link_agent_memory_block(self, agent_id: str, block_id: str) -> Memory:
|
||||
"""
|
||||
Link a block to an agent's core memory
|
||||
|
||||
Args:
|
||||
agent_id (str): The agent ID
|
||||
block_id (str): The block ID
|
||||
|
||||
Returns:
|
||||
memory (Memory): The updated memory
|
||||
"""
|
||||
return self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=self.user)
|
||||
|
||||
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
|
||||
"""
|
||||
Unlike a block from the agent's core memory
|
||||
|
||||
Args:
|
||||
agent_id (str): The agent ID
|
||||
block_label (str): The block label
|
||||
|
||||
Returns:
|
||||
memory (Memory): The updated memory
|
||||
"""
|
||||
return self.server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=self.user)
|
||||
|
||||
def list_agent_memory_blocks(self, agent_id: str) -> List[Block]:
|
||||
def get_agent_memory_blocks(self, agent_id: str) -> List[Block]:
|
||||
"""
|
||||
Get all the blocks in the agent's core memory
|
||||
|
||||
@@ -3600,6 +3595,26 @@ class LocalClient(AbstractClient):
|
||||
data["label"] = label
|
||||
return self.server.block_manager.update_block(block_id, actor=self.user, block_update=BlockUpdate(**data))
|
||||
|
||||
def attach_block(self, agent_id: str, block_id: str) -> AgentState:
|
||||
"""
|
||||
Attach a block to an agent.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
block_id (str): ID of the block to attach
|
||||
"""
|
||||
return self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=self.user)
|
||||
|
||||
def detach_block(self, agent_id: str, block_id: str) -> AgentState:
|
||||
"""
|
||||
Detach a block from an agent.
|
||||
|
||||
Args:
|
||||
agent_id (str): ID of the agent
|
||||
block_id (str): ID of the block to detach
|
||||
"""
|
||||
return self.server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=self.user)
|
||||
|
||||
def get_run_messages(
|
||||
self,
|
||||
run_id: str,
|
||||
|
||||
@@ -126,43 +126,63 @@ def get_tools_from_agent(
|
||||
):
|
||||
"""Get tools from an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).tools
|
||||
return server.agent_manager.list_attached_tools(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent")
|
||||
def add_tool_to_agent(
|
||||
@router.patch("/{agent_id}/tools/attach/{tool_id}", response_model=AgentState, operation_id="attach_tool_to_agent")
|
||||
def attach_tool(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""Add tools to an existing agent"""
|
||||
"""
|
||||
Attach a tool to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/remove-tool/{tool_id}", response_model=AgentState, operation_id="remove_tool_from_agent")
|
||||
def remove_tool_from_agent(
|
||||
@router.patch("/{agent_id}/tools/detach/{tool_id}", response_model=AgentState, operation_id="detach_tool_from_agent")
|
||||
def detach_tool(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""Add tools to an existing agent"""
|
||||
"""
|
||||
Detach a tool from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/reset-messages", response_model=AgentState, operation_id="reset_messages")
|
||||
def reset_messages(
|
||||
@router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent")
|
||||
def attach_source(
|
||||
agent_id: str,
|
||||
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""Resets the messages for an agent"""
|
||||
"""
|
||||
Attach a source to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages)
|
||||
return server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/sources/detach/{source_id}", response_model=AgentState, operation_id="detach_source_from_agent")
|
||||
def detach_source(
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a source from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent")
|
||||
@@ -263,49 +283,6 @@ def list_agent_memory_blocks(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{agent_id}/core_memory/blocks", response_model=Memory, operation_id="add_agent_memory_block")
|
||||
def add_agent_memory_block(
|
||||
agent_id: str,
|
||||
create_block: CreateBlock = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Creates a memory block and links it to the agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Copied from POST /blocks
|
||||
# TODO: Should have block_manager accept only CreateBlock
|
||||
# TODO: This will be possible once we move ID creation to the ORM
|
||||
block_req = Block(**create_block.model_dump())
|
||||
block = server.block_manager.create_or_update_block(actor=actor, block=block_req)
|
||||
|
||||
# Link the block to the agent
|
||||
agent = server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=actor)
|
||||
return agent.memory
|
||||
|
||||
|
||||
@router.delete("/{agent_id}/core_memory/blocks/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block_by_label")
|
||||
def remove_agent_memory_block(
|
||||
agent_id: str,
|
||||
# TODO should this be block_id, or the label?
|
||||
# I think label is OK since it's user-friendly + guaranteed to be unique within a Memory object
|
||||
block_label: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Unlink the block from the agent
|
||||
agent = server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
||||
|
||||
return agent.memory
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/core_memory/blocks/{block_label}", response_model=Block, operation_id="update_agent_memory_block_by_label")
|
||||
def update_agent_memory_block(
|
||||
agent_id: str,
|
||||
@@ -315,7 +292,7 @@ def update_agent_memory_block(
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted.
|
||||
Updates a memory block of an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
@@ -328,6 +305,34 @@ def update_agent_memory_block(
|
||||
return block
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/core_memory/blocks/attach/{block_id}", response_model=AgentState, operation_id="attach_block_to_agent")
|
||||
def attach_block(
|
||||
agent_id: str,
|
||||
block_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a block to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/core_memory/blocks/detach/{block_id}", response_model=AgentState, operation_id="detach_block_from_agent")
|
||||
def detach_block(
|
||||
agent_id: str,
|
||||
block_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a block from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/archival_memory", response_model=List[Passage], operation_id="list_agent_archival_memory")
|
||||
def get_agent_archival_memory(
|
||||
agent_id: str,
|
||||
@@ -610,3 +615,15 @@ async def send_message_async(
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/reset-messages", response_model=AgentState, operation_id="reset_messages")
|
||||
def reset_messages(
|
||||
agent_id: str,
|
||||
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Resets the messages for an agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Response
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.block import Block, BlockUpdate, CreateBlock
|
||||
@@ -73,41 +73,3 @@ def get_block(
|
||||
return block
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Block not found")
|
||||
|
||||
|
||||
@router.patch("/{block_id}/attach", response_model=None, status_code=204, operation_id="link_agent_memory_block")
|
||||
def link_agent_memory_block(
|
||||
block_id: str,
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Link a memory block to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
try:
|
||||
server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
return Response(status_code=204)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/{block_id}/detach", response_model=None, status_code=204, operation_id="unlink_agent_memory_block")
|
||||
def unlink_agent_memory_block(
|
||||
block_id: str,
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Unlink a memory block from an agent
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
try:
|
||||
server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
return Response(status_code=204)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -111,36 +111,6 @@ def delete_source(
|
||||
server.delete_source(source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@router.post("/{source_id}/attach", response_model=Source, operation_id="attach_agent_to_source")
|
||||
def attach_source_to_agent(
|
||||
source_id: str,
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Attach a data source to an existing agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=actor)
|
||||
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@router.post("/{source_id}/detach", response_model=Source, operation_id="detach_agent_from_source")
|
||||
def detach_source_from_agent(
|
||||
source_id: str,
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
) -> None:
|
||||
"""
|
||||
Detach a data source from an existing agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
return server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source")
|
||||
def upload_file_to_source(
|
||||
file: UploadFile,
|
||||
|
||||
@@ -25,6 +25,7 @@ from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.block_manager import BlockManager
|
||||
@@ -537,7 +538,7 @@ class AgentManager:
|
||||
# Source Management
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None:
|
||||
def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches a source to an agent.
|
||||
|
||||
@@ -567,6 +568,7 @@ class AgentManager:
|
||||
|
||||
# Commit the changes
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
|
||||
@@ -588,7 +590,7 @@ class AgentManager:
|
||||
return [source.to_pydantic() for source in agent.sources]
|
||||
|
||||
@enforce_types
|
||||
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None:
|
||||
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detaches a source from an agent.
|
||||
|
||||
@@ -602,10 +604,17 @@ class AgentManager:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Remove the source from the relationship
|
||||
agent.sources = [s for s in agent.sources if s.id != source_id]
|
||||
remaining_sources = [s for s in agent.sources if s.id != source_id]
|
||||
|
||||
if len(remaining_sources) == len(agent.sources): # Source ID was not in the relationship
|
||||
logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")
|
||||
|
||||
# Update the sources relationship
|
||||
agent.sources = remaining_sources
|
||||
|
||||
# Commit the changes
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block management
|
||||
@@ -1011,6 +1020,22 @@ class AgentManager:
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:
|
||||
"""
|
||||
List all tools attached to an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to list tools for.
|
||||
actor: User performing the action.
|
||||
|
||||
Returns:
|
||||
List[PydanticTool]: List of tools attached to the agent.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
return [tool.to_pydantic() for tool in agent.tools]
|
||||
|
||||
# ======================================================================================================================
|
||||
# Tag Management
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -13,7 +13,6 @@ from sqlalchemy import delete
|
||||
from letta import LocalClient, RESTClient, create_client
|
||||
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.job import JobStatus
|
||||
@@ -113,46 +112,6 @@ def clear_tables():
|
||||
session.commit()
|
||||
|
||||
|
||||
def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]):
|
||||
# _reset_config()
|
||||
|
||||
# create a block
|
||||
block = client.create_block(label="human", value="username: sarah")
|
||||
|
||||
# create agents with shared block
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.memory import BasicBlockMemory
|
||||
|
||||
# persona1_block = client.create_block(label="persona", value="you are agent 1")
|
||||
# persona2_block = client.create_block(label="persona", value="you are agent 2")
|
||||
# create agents
|
||||
agent_state1 = client.create_agent(
|
||||
name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id]
|
||||
)
|
||||
agent_state2 = client.create_agent(
|
||||
name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id]
|
||||
)
|
||||
|
||||
## attach shared block to both agents
|
||||
# client.link_agent_memory_block(agent_state1.id, block.id)
|
||||
# client.link_agent_memory_block(agent_state2.id, block.id)
|
||||
|
||||
# update memory
|
||||
client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
|
||||
|
||||
# check agent 2 memory
|
||||
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
|
||||
|
||||
client.user_message(agent_id=agent_state2.id, message="whats my name?")
|
||||
assert (
|
||||
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
|
||||
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
|
||||
|
||||
# cleanup
|
||||
client.delete_agent(agent_state1.id)
|
||||
client.delete_agent(agent_state2.id)
|
||||
|
||||
|
||||
def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
|
||||
"""
|
||||
Test sandbox config and environment variable functions for both LocalClient and RESTClient.
|
||||
@@ -204,6 +163,11 @@ def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]
|
||||
client.delete_sandbox_config(sandbox_config_id=sandbox_config.id)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
# Agent tags
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient]):
|
||||
"""
|
||||
Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
|
||||
@@ -306,6 +270,49 @@ def test_agent_tags(client: Union[LocalClient, RESTClient]):
|
||||
client.delete_agent(agent3.id)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
# Agent memory blocks
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]):
|
||||
# _reset_config()
|
||||
|
||||
# create a block
|
||||
block = client.create_block(label="human", value="username: sarah")
|
||||
|
||||
# create agents with shared block
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.memory import BasicBlockMemory
|
||||
|
||||
# persona1_block = client.create_block(label="persona", value="you are agent 1")
|
||||
# persona2_block = client.create_block(label="persona", value="you are agent 2")
|
||||
# create agents
|
||||
agent_state1 = client.create_agent(
|
||||
name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id]
|
||||
)
|
||||
agent_state2 = client.create_agent(
|
||||
name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id]
|
||||
)
|
||||
|
||||
## attach shared block to both agents
|
||||
# client.link_agent_memory_block(agent_state1.id, block.id)
|
||||
# client.link_agent_memory_block(agent_state2.id, block.id)
|
||||
|
||||
# update memory
|
||||
client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
|
||||
|
||||
# check agent 2 memory
|
||||
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
|
||||
|
||||
client.user_message(agent_id=agent_state2.id, message="whats my name?")
|
||||
assert (
|
||||
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
|
||||
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
|
||||
|
||||
# cleanup
|
||||
client.delete_agent(agent_state1.id)
|
||||
client.delete_agent(agent_state2.id)
|
||||
|
||||
|
||||
def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can update the label of a block in an agent's memory"""
|
||||
|
||||
@@ -326,38 +333,32 @@ def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
|
||||
def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_attach_detach_agent_memory_block(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can add and remove a block from an agent's memory"""
|
||||
|
||||
agent = client.create_agent(name=create_random_username())
|
||||
current_labels = agent.memory.list_block_labels()
|
||||
example_new_label = current_labels[0] + "_v2"
|
||||
example_new_value = "example value"
|
||||
assert example_new_label not in current_labels
|
||||
|
||||
try:
|
||||
current_labels = agent.memory.list_block_labels()
|
||||
example_new_label = "example_new_label"
|
||||
example_new_value = "example value"
|
||||
assert example_new_label not in current_labels
|
||||
# Link a new memory block
|
||||
block = client.create_block(
|
||||
label=example_new_label,
|
||||
value=example_new_value,
|
||||
limit=1000,
|
||||
)
|
||||
updated_agent = client.attach_block(
|
||||
agent_id=agent.id,
|
||||
block_id=block.id,
|
||||
)
|
||||
assert example_new_label in updated_agent.memory.list_block_labels()
|
||||
|
||||
# Link a new memory block
|
||||
client.add_agent_memory_block(
|
||||
agent_id=agent.id,
|
||||
create_block=CreateBlock(
|
||||
label=example_new_label,
|
||||
value=example_new_value,
|
||||
limit=1000,
|
||||
),
|
||||
)
|
||||
|
||||
updated_agent = client.get_agent(agent_id=agent.id)
|
||||
assert example_new_label in updated_agent.memory.list_block_labels()
|
||||
|
||||
# Now unlink the block
|
||||
client.remove_agent_memory_block(agent_id=agent.id, block_label=example_new_label)
|
||||
|
||||
updated_agent = client.get_agent(agent_id=agent.id)
|
||||
assert example_new_label not in updated_agent.memory.list_block_labels()
|
||||
|
||||
finally:
|
||||
client.delete_agent(agent.id)
|
||||
# Now unlink the block
|
||||
updated_agent = client.detach_block(
|
||||
agent_id=agent.id,
|
||||
block_id=block.id,
|
||||
)
|
||||
assert example_new_label not in updated_agent.memory.list_block_labels()
|
||||
|
||||
|
||||
# def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
@@ -413,24 +414,9 @@ def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient]):
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
|
||||
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
|
||||
assert send_message_response, "Sending message failed"
|
||||
|
||||
messages_response = client.get_messages(agent_id=agent.id, limit=1)
|
||||
assert len(messages_response) > 0, "Retrieving messages failed"
|
||||
|
||||
|
||||
def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)"""
|
||||
send_system_message_response = client.send_message(
|
||||
agent_id=agent.id, message="Event occurred: The user just logged off.", role="system"
|
||||
)
|
||||
assert send_system_message_response, "Sending message failed"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
# Agent Tools
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
def test_function_return_limit(client: Union[LocalClient, RESTClient]):
|
||||
"""Test to see if the function return limit works"""
|
||||
|
||||
@@ -503,6 +489,70 @@ def test_function_always_error(client: Union[LocalClient, RESTClient]):
|
||||
client.delete_agent(agent_id=agent.id)
|
||||
|
||||
|
||||
def test_attach_detach_agent_tool(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can attach and detach a tool from an agent"""
|
||||
|
||||
try:
|
||||
# Create a tool
|
||||
def example_tool(x: int) -> int:
|
||||
"""
|
||||
This is an example tool.
|
||||
|
||||
Parameters:
|
||||
x (int): The input value.
|
||||
|
||||
Returns:
|
||||
int: The output value.
|
||||
"""
|
||||
return x * 2
|
||||
|
||||
tool = client.create_or_update_tool(func=example_tool, name="test_tool")
|
||||
|
||||
# Initially tool should not be attached
|
||||
initial_tools = client.list_attached_tools(agent_id=agent.id)
|
||||
assert tool.id not in [t.id for t in initial_tools]
|
||||
|
||||
# Attach tool
|
||||
new_agent_state = client.attach_tool(agent_id=agent.id, tool_id=tool.id)
|
||||
assert tool.id in [t.id for t in new_agent_state.tools]
|
||||
|
||||
# Verify tool is attached
|
||||
updated_tools = client.list_attached_tools(agent_id=agent.id)
|
||||
assert tool.id in [t.id for t in updated_tools]
|
||||
|
||||
# Detach tool
|
||||
new_agent_state = client.detach_tool(agent_id=agent.id, tool_id=tool.id)
|
||||
assert tool.id not in [t.id for t in new_agent_state.tools]
|
||||
|
||||
# Verify tool is detached
|
||||
final_tools = client.list_attached_tools(agent_id=agent.id)
|
||||
assert tool.id not in [t.id for t in final_tools]
|
||||
|
||||
finally:
|
||||
client.delete_tool(tool.id)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
# AgentMessages
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
|
||||
assert send_message_response, "Sending message failed"
|
||||
|
||||
messages_response = client.get_messages(agent_id=agent.id, limit=1)
|
||||
assert len(messages_response) > 0, "Retrieving messages failed"
|
||||
|
||||
|
||||
def test_send_system_message(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Important unit test since the Letta API exposes sending system messages, but some backends don't natively support it (eg Anthropic)"""
|
||||
send_system_message_response = client.send_message(
|
||||
agent_id=agent.id, message="Event occurred: The user just logged off.", role="system"
|
||||
)
|
||||
assert send_system_message_response, "Sending message failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
|
||||
"""
|
||||
@@ -580,9 +630,9 @@ def test_send_message_async(client: Union[LocalClient, RESTClient], agent: Agent
|
||||
assert usage.total_tokens == usage.completion_tokens + usage.prompt_tokens
|
||||
|
||||
|
||||
# ==========================================
|
||||
# TESTS FOR AGENT LISTING
|
||||
# ==========================================
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
# Agent listing
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_agent_one, search_agent_two):
|
||||
@@ -678,3 +728,33 @@ def test_agent_creation(client: Union[LocalClient, RESTClient]):
|
||||
assert all(tool.id in tool_ids for tool in agent_tools)
|
||||
|
||||
client.delete_agent(agent_id=agent.id)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
# Agent sources
|
||||
# --------------------------------------------------------------------------------------------------------------------
|
||||
def test_attach_detach_agent_source(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can attach and detach a source from an agent"""
|
||||
|
||||
# Create a source
|
||||
source = client.create_source(
|
||||
name="test_source",
|
||||
)
|
||||
initial_sources = client.list_attached_sources(agent_id=agent.id)
|
||||
assert source.id not in [s.id for s in initial_sources]
|
||||
|
||||
# Attach source
|
||||
client.attach_source(agent_id=agent.id, source_id=source.id)
|
||||
|
||||
# Verify source is attached
|
||||
final_sources = client.list_attached_sources(agent_id=agent.id)
|
||||
assert source.id in [s.id for s in final_sources]
|
||||
|
||||
# Detach source
|
||||
client.detach_source(agent_id=agent.id, source_id=source.id)
|
||||
|
||||
# Verify source is detached
|
||||
final_sources = client.list_attached_sources(agent_id=agent.id)
|
||||
assert source.id not in [s.id for s in final_sources]
|
||||
|
||||
client.delete_source(source.id)
|
||||
|
||||
@@ -205,9 +205,9 @@ def test_archival_memory(mock_e2b_api_key_none, client: Union[LocalClient, RESTC
|
||||
passages = client.get_archival_memory(agent.id)
|
||||
assert passage.text in [p.text for p in passages], f"Missing passage {passage.text} in {passages}"
|
||||
|
||||
# get archival memory summary
|
||||
archival_summary = client.get_archival_memory_summary(agent.id)
|
||||
assert archival_summary.size == 1, f"Archival memory summary size is {archival_summary.size}"
|
||||
# # get archival memory summary
|
||||
# archival_summary = client.get_agent_archival_memory_summary(agent.id)
|
||||
# assert archival_summary.size == 1, f"Archival memory summary size is {archival_summary.size}"
|
||||
|
||||
# delete archival memory
|
||||
client.delete_archival_memory(agent.id, passage.id)
|
||||
@@ -500,7 +500,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
assert len(archival_memories) == 0
|
||||
|
||||
# attach a source
|
||||
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)
|
||||
client.attach_source(source_id=source.id, agent_id=agent.id)
|
||||
|
||||
# list attached sources
|
||||
attached_sources = client.list_attached_sources(agent_id=agent.id)
|
||||
@@ -521,8 +521,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
# detach the source
|
||||
assert len(client.get_archival_memory(agent_id=agent.id)) > 0, "No archival memory"
|
||||
deleted_source = client.detach_source(source_id=source.id, agent_id=agent.id)
|
||||
assert deleted_source.id == source.id
|
||||
client.detach_source(source_id=source.id, agent_id=agent.id)
|
||||
archival_memories = client.get_archival_memory(agent_id=agent.id)
|
||||
assert len(archival_memories) == 0, f"Failed to detach source: {len(archival_memories)}"
|
||||
assert source.id not in [s.id for s in client.list_attached_sources(agent.id)]
|
||||
|
||||
@@ -141,7 +141,7 @@ def test_agent_add_remove_tools(client: LocalClient, agent):
|
||||
curr_num_tools = len(agent_state.tools)
|
||||
|
||||
# add both tools to agent in steps
|
||||
agent_state = client.add_tool_to_agent(agent_id=agent_state.id, tool_id=github_tool.id)
|
||||
agent_state = client.attach_tool(agent_id=agent_state.id, tool_id=github_tool.id)
|
||||
|
||||
# confirm that both tools are in the agent state
|
||||
# we could access it like agent_state.tools, but will use the client function instead
|
||||
@@ -153,7 +153,7 @@ def test_agent_add_remove_tools(client: LocalClient, agent):
|
||||
assert github_tool.name in curr_tool_names
|
||||
|
||||
# remove only the github tool
|
||||
agent_state = client.remove_tool_from_agent(agent_id=agent_state.id, tool_id=github_tool.id)
|
||||
agent_state = client.detach_tool(agent_id=agent_state.id, tool_id=github_tool.id)
|
||||
|
||||
# confirm that only one tool left
|
||||
curr_tools = client.get_tools_from_agent(agent_state.id)
|
||||
|
||||
Reference in New Issue
Block a user