feat: Rewrite agents (#2232)

This commit is contained in:
Matthew Zhou
2024-12-13 14:43:19 -08:00
committed by GitHub
parent 65fd731917
commit 7908b8a15f
86 changed files with 2495 additions and 3980 deletions

View File

@@ -15,7 +15,8 @@ from letta.constants import (
)
from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import Block, BlockUpdate, CreateBlock, Human, Persona
from letta.schemas.embedding_config import EmbeddingConfig
@@ -65,10 +66,8 @@ def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
class AbstractClient(object):
def __init__(
self,
auto_save: bool = False,
debug: bool = False,
):
self.auto_save = auto_save
self.debug = debug
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
@@ -81,8 +80,9 @@ class AbstractClient(object):
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
memory=None,
block_ids: Optional[List[str]] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
tool_rules: Optional[List[BaseToolRule]] = None,
include_base_tools: Optional[bool] = True,
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
@@ -97,7 +97,7 @@ class AbstractClient(object):
name: Optional[str] = None,
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
@@ -436,7 +436,6 @@ class RESTClient(AbstractClient):
Initializes a new instance of Client class.
Args:
auto_save (bool): Whether to automatically save changes.
user_id (str): The user ID.
debug (bool): Whether to print debug information.
default_llm_config (Optional[LLMConfig]): The default LLM configuration.
@@ -456,6 +455,7 @@ class RESTClient(AbstractClient):
params = {}
if tags:
params["tags"] = tags
params["match_all_tags"] = False
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params)
return [AgentState(**agent) for agent in response.json()]
@@ -491,10 +491,12 @@ class RESTClient(AbstractClient):
llm_config: LLMConfig = None,
# memory
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
# Existing blocks
block_ids: Optional[List[str]] = None,
# system
system: Optional[str] = None,
# tools
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
tool_rules: Optional[List[BaseToolRule]] = None,
include_base_tools: Optional[bool] = True,
# metadata
@@ -511,7 +513,7 @@ class RESTClient(AbstractClient):
llm_config (LLMConfig): LLM configuration
memory (Memory): Memory configuration
system (str): System configuration
tools (List[str]): List of tools
tool_ids (List[str]): List of tool ids
include_base_tools (bool): Include base tools
metadata (Dict): Metadata
description (str): Description
@@ -520,31 +522,54 @@ class RESTClient(AbstractClient):
Returns:
agent_state (AgentState): State of the created agent
"""
tool_ids = tool_ids or []
tool_names = []
if tools:
tool_names += tools
if include_base_tools:
tool_names += BASE_TOOLS
tool_names += BASE_MEMORY_TOOLS
tool_ids += [self.get_tool_id(tool_name=name) for name in tool_names]
assert embedding_config or self._default_embedding_config, f"Embedding config must be provided"
assert llm_config or self._default_llm_config, f"LLM config must be provided"
# TODO: This should not happen here, we need to have clear separation between create/add blocks
# TODO: This is insanely hacky and a result of allowing free-floating blocks
# TODO: When we create the block, it gets it's own block ID
blocks = []
for block in memory.get_blocks():
blocks.append(
self.create_block(
label=block.label,
value=block.value,
limit=block.limit,
template_name=block.template_name,
is_template=block.is_template,
)
)
memory.blocks = blocks
block_ids = block_ids or []
# create agent
request = CreateAgent(
name=name,
description=description,
metadata_=metadata,
memory_blocks=[],
tools=tool_names,
tool_rules=tool_rules,
system=system,
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
tags=tags,
)
create_params = {
"description": description,
"metadata_": metadata,
"memory_blocks": [],
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
"tool_ids": tool_ids,
"tool_rules": tool_rules,
"system": system,
"agent_type": agent_type,
"llm_config": llm_config if llm_config else self._default_llm_config,
"embedding_config": embedding_config if embedding_config else self._default_embedding_config,
"initial_message_sequence": initial_message_sequence,
"tags": tags,
}
# Only add name if it's not None
if name is not None:
create_params["name"] = name
request = CreateAgent(**create_params)
# Use model_dump_json() instead of model_dump()
# If we use model_dump(), the datetime objects will not be serialized correctly
@@ -561,14 +586,6 @@ class RESTClient(AbstractClient):
# gather agent state
agent_state = AgentState(**response.json())
# create and link blocks
for block in memory.get_blocks():
if not self.get_block(block.id):
# note: this does not update existing blocks
# WARNING: this resets the block ID - this method is a hack for backwards compat, should eventually use CreateBlock not Memory
block = self.create_block(label=block.label, value=block.value, limit=block.limit)
self.link_agent_memory_block(agent_id=agent_state.id, block_id=block.id)
# refresh and return agent
return self.get_agent(agent_state.id)
@@ -602,7 +619,7 @@ class RESTClient(AbstractClient):
name: Optional[str] = None,
description: Optional[str] = None,
system: Optional[str] = None,
tool_names: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
@@ -617,7 +634,7 @@ class RESTClient(AbstractClient):
name (str): Name of the agent
description (str): Description of the agent
system (str): System configuration
tool_names (List[str]): List of tools
tool_ids (List[str]): List of tools
metadata (Dict): Metadata
llm_config (LLMConfig): LLM configuration
embedding_config (EmbeddingConfig): Embedding configuration
@@ -627,11 +644,10 @@ class RESTClient(AbstractClient):
Returns:
agent_state (AgentState): State of the updated agent
"""
request = UpdateAgentState(
id=agent_id,
request = UpdateAgent(
name=name,
system=system,
tool_names=tool_names,
tool_ids=tool_ids,
tags=tags,
description=description,
metadata_=metadata,
@@ -742,7 +758,7 @@ class RESTClient(AbstractClient):
agents = [AgentState(**agent) for agent in response.json()]
if len(agents) == 0:
return None
agents = [agents[0]] # TODO: @matt monkeypatched
agents = [agents[0]] # TODO: @matt monkeypatched
assert len(agents) == 1, f"Multiple agents with the same name: {[(agents.name, agents.id) for agents in agents]}"
return agents[0].id
@@ -1052,7 +1068,7 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to update block: {response.text}")
return Block(**response.json())
def get_block(self, block_id: str) -> Block:
def get_block(self, block_id: str) -> Optional[Block]:
response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", headers=self.headers)
if response.status_code == 404:
return None
@@ -1607,23 +1623,6 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to get tool: {response.text}")
return Tool(**response.json())
def get_tool_id(self, name: str) -> Optional[str]:
"""
Get a tool ID by its name.
Args:
id (str): ID of the tool
Returns:
tool (Tool): Tool
"""
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/name/{name}", headers=self.headers)
if response.status_code == 404:
return None
elif response.status_code != 200:
raise ValueError(f"Failed to get tool: {response.text}")
return response.json()
def set_default_llm_config(self, llm_config: LLMConfig):
"""
Set the default LLM configuration
@@ -2006,7 +2005,6 @@ class LocalClient(AbstractClient):
A local client for Letta, which corresponds to a single user.
Attributes:
auto_save (bool): Whether to automatically save changes.
user_id (str): The user ID.
debug (bool): Whether to print debug information.
interface (QueuingInterface): The interface for the client.
@@ -2015,7 +2013,6 @@ class LocalClient(AbstractClient):
def __init__(
self,
auto_save: bool = False,
user_id: Optional[str] = None,
org_id: Optional[str] = None,
debug: bool = False,
@@ -2026,11 +2023,9 @@ class LocalClient(AbstractClient):
Initializes a new instance of Client class.
Args:
auto_save (bool): Whether to automatically save changes.
user_id (str): The user ID.
debug (bool): Whether to print debug information.
"""
self.auto_save = auto_save
# set logging levels
letta.utils.DEBUG = debug
@@ -2056,14 +2051,14 @@ class LocalClient(AbstractClient):
# get default user
self.user_id = self.server.user_manager.DEFAULT_USER_ID
self.user = self.server.get_user_or_default(self.user_id)
self.user = self.server.user_manager.get_user_or_default(self.user_id)
self.organization = self.server.get_organization_or_default(self.org_id)
# agents
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
self.interface.clear()
return self.server.list_agents(user_id=self.user_id, tags=tags)
return self.server.agent_manager.list_agents(actor=self.user, tags=tags)
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
"""
@@ -2097,6 +2092,7 @@ class LocalClient(AbstractClient):
llm_config: LLMConfig = None,
# memory
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
block_ids: Optional[List[str]] = None,
# TODO: change to this when we are ready to migrate all the tests/examples (matches the REST API)
# memory_blocks=[
# {"label": "human", "value": get_human_text(DEFAULT_HUMAN), "limit": 5000},
@@ -2105,7 +2101,7 @@ class LocalClient(AbstractClient):
# system
system: Optional[str] = None,
# tools
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
tool_rules: Optional[List[BaseToolRule]] = None,
include_base_tools: Optional[bool] = True,
# metadata
@@ -2132,55 +2128,53 @@ class LocalClient(AbstractClient):
Returns:
agent_state (AgentState): State of the created agent
"""
if name and self.agent_exists(agent_name=name):
raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})")
# construct list of tools
tool_ids = tool_ids or []
tool_names = []
if tools:
tool_names += tools
if include_base_tools:
tool_names += BASE_TOOLS
tool_names += BASE_MEMORY_TOOLS
tool_ids += [self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user).id for name in tool_names]
# check if default configs are provided
assert embedding_config or self._default_embedding_config, f"Embedding config must be provided"
assert llm_config or self._default_llm_config, f"LLM config must be provided"
# TODO: This should not happen here, we need to have clear separation between create/add blocks
for block in memory.get_blocks():
self.server.block_manager.create_or_update_block(block, actor=self.user)
# Also get any existing block_ids passed in
block_ids = block_ids or []
# create agent
# Create the base parameters
create_params = {
"description": description,
"metadata_": metadata,
"memory_blocks": [],
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
"tool_ids": tool_ids,
"tool_rules": tool_rules,
"system": system,
"agent_type": agent_type,
"llm_config": llm_config if llm_config else self._default_llm_config,
"embedding_config": embedding_config if embedding_config else self._default_embedding_config,
"initial_message_sequence": initial_message_sequence,
"tags": tags,
}
# Only add name if it's not None
if name is not None:
create_params["name"] = name
agent_state = self.server.create_agent(
CreateAgent(
name=name,
description=description,
metadata_=metadata,
# memory=memory,
memory_blocks=[],
# memory_blocks = memory.get_blocks(),
# memory_tools=memory_tools,
tools=tool_names,
tool_rules=tool_rules,
system=system,
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
tags=tags,
),
CreateAgent(**create_params),
actor=self.user,
)
# TODO: remove when we fully migrate to block creation CreateAgent model
# Link additional blocks to the agent (block ids created on the client)
# This needs to happen since the create agent does not allow passing in blocks which have already been persisted and have an ID
# So we create the agent and then link the blocks afterwards
user = self.server.get_user_or_default(self.user_id)
for block in memory.get_blocks():
self.server.block_manager.create_or_update_block(block, actor=user)
self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_state.id, block_id=block.id)
# TODO: get full agent state
return self.server.get_agent(agent_state.id)
return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user)
def update_message(
self,
@@ -2202,6 +2196,7 @@ class LocalClient(AbstractClient):
tool_calls=tool_calls,
tool_call_id=tool_call_id,
),
actor=self.user,
)
return message
@@ -2211,7 +2206,7 @@ class LocalClient(AbstractClient):
name: Optional[str] = None,
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
@@ -2239,11 +2234,11 @@ class LocalClient(AbstractClient):
# TODO: add the abilitty to reset linked block_ids
self.interface.clear()
agent_state = self.server.update_agent(
UpdateAgentState(
id=agent_id,
agent_id,
UpdateAgent(
name=name,
system=system,
tool_names=tools,
tool_ids=tool_ids,
tags=tags,
description=description,
metadata_=metadata,
@@ -2315,7 +2310,7 @@ class LocalClient(AbstractClient):
Args:
agent_id (str): ID of the agent to delete
"""
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
self.server.agent_manager.delete_agent(agent_id=agent_id, actor=self.user)
def get_agent_by_name(self, agent_name: str) -> AgentState:
"""
@@ -2328,7 +2323,7 @@ class LocalClient(AbstractClient):
agent_state (AgentState): State of the agent
"""
self.interface.clear()
return self.server.get_agent_state(agent_name=agent_name, user_id=self.user_id, agent_id=None)
return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user)
def get_agent(self, agent_id: str) -> AgentState:
"""
@@ -2340,9 +2335,8 @@ class LocalClient(AbstractClient):
Returns:
agent_state (AgentState): State representation of the agent
"""
# TODO: include agent_name
self.interface.clear()
return self.server.get_agent_state(user_id=self.user_id, agent_id=agent_id)
return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
def get_agent_id(self, agent_name: str) -> Optional[str]:
"""
@@ -2357,7 +2351,12 @@ class LocalClient(AbstractClient):
self.interface.clear()
assert agent_name, f"Agent name must be provided"
return self.server.get_agent_id(name=agent_name, user_id=self.user_id)
# TODO: Refactor this futher to not have downstream users expect Optionals - this should just error
try:
return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user).id
except NoResultFound:
return None
# memory
def get_in_context_memory(self, agent_id: str) -> Memory:
@@ -2370,7 +2369,7 @@ class LocalClient(AbstractClient):
Returns:
memory (Memory): In-context memory of the agent
"""
memory = self.server.get_agent_memory(agent_id=agent_id)
memory = self.server.get_agent_memory(agent_id=agent_id, actor=self.user)
return memory
def get_core_memory(self, agent_id: str) -> Memory:
@@ -2388,7 +2387,7 @@ class LocalClient(AbstractClient):
"""
# TODO: implement this (not sure what it should look like)
memory = self.server.update_agent_core_memory(user_id=self.user_id, agent_id=agent_id, label=section, value=value)
memory = self.server.update_agent_core_memory(agent_id=agent_id, label=section, value=value, actor=self.user)
return memory
def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary:
@@ -2402,7 +2401,7 @@ class LocalClient(AbstractClient):
summary (ArchivalMemorySummary): Summary of the archival memory
"""
return self.server.get_archival_memory_summary(agent_id=agent_id)
return self.server.get_archival_memory_summary(agent_id=agent_id, actor=self.user)
def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary:
"""
@@ -2414,7 +2413,7 @@ class LocalClient(AbstractClient):
Returns:
summary (RecallMemorySummary): Summary of the recall memory
"""
return self.server.get_recall_memory_summary(agent_id=agent_id)
return self.server.get_recall_memory_summary(agent_id=agent_id, actor=self.user)
def get_in_context_messages(self, agent_id: str) -> List[Message]:
"""
@@ -2426,7 +2425,7 @@ class LocalClient(AbstractClient):
Returns:
messages (List[Message]): List of in-context messages
"""
return self.server.get_in_context_messages(agent_id=agent_id)
return self.server.get_in_context_messages(agent_id=agent_id, actor=self.user)
# agent interactions
@@ -2446,11 +2445,7 @@ class LocalClient(AbstractClient):
response (LettaResponse): Response from the agent
"""
self.interface.clear()
usage = self.server.send_messages(user_id=self.user_id, agent_id=agent_id, messages=messages)
# auto-save
if self.auto_save:
self.save()
usage = self.server.send_messages(actor=self.user, agent_id=agent_id, messages=messages)
# format messages
return LettaResponse(messages=messages, usage=usage)
@@ -2490,15 +2485,11 @@ class LocalClient(AbstractClient):
self.interface.clear()
usage = self.server.send_messages(
user_id=self.user_id,
actor=self.user,
agent_id=agent_id,
messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
)
# auto-save
if self.auto_save:
self.save()
## TODO: need to make sure date/timestamp is propely passed
## TODO: update self.interface.to_list() to return actual Message objects
## here, the message objects will have faulty created_by timestamps
@@ -2547,16 +2538,9 @@ class LocalClient(AbstractClient):
self.interface.clear()
usage = self.server.run_command(user_id=self.user_id, agent_id=agent_id, command=command)
# auto-save
if self.auto_save:
self.save()
# NOTE: messages/usage may be empty, depending on the command
return LettaResponse(messages=self.interface.to_list(), usage=usage)
def save(self):
self.server.save_agents()
# archival memory
# humans / personas
@@ -3036,7 +3020,7 @@ class LocalClient(AbstractClient):
Returns:
sources (List[Source]): List of sources
"""
return self.server.list_attached_sources(agent_id=agent_id)
return self.server.agent_manager.list_attached_sources(agent_id=agent_id, actor=self.user)
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
"""
@@ -3080,7 +3064,7 @@ class LocalClient(AbstractClient):
Returns:
passages (List[Passage]): List of inserted passages
"""
return self.server.insert_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_contents=memory)
return self.server.insert_archival_memory(agent_id=agent_id, memory_contents=memory, actor=self.user)
def delete_archival_memory(self, agent_id: str, memory_id: str):
"""
@@ -3090,7 +3074,7 @@ class LocalClient(AbstractClient):
agent_id (str): ID of the agent
memory_id (str): ID of the memory
"""
self.server.delete_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_id=memory_id)
self.server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=self.user)
def get_archival_memory(
self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000
@@ -3349,8 +3333,8 @@ class LocalClient(AbstractClient):
block_req = Block(**create_block.model_dump())
block = self.server.block_manager.create_or_update_block(actor=self.user, block=block_req)
# Link the block to the agent
updated_memory = self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block.id)
return updated_memory
agent = self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=self.user)
return agent.memory
def link_agent_memory_block(self, agent_id: str, block_id: str) -> Memory:
"""
@@ -3363,7 +3347,7 @@ class LocalClient(AbstractClient):
Returns:
memory (Memory): The updated memory
"""
return self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block_id)
return self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=self.user)
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
"""
@@ -3376,7 +3360,7 @@ class LocalClient(AbstractClient):
Returns:
memory (Memory): The updated memory
"""
return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label)
return self.server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=self.user)
def get_agent_memory_blocks(self, agent_id: str) -> List[Block]:
"""
@@ -3388,8 +3372,8 @@ class LocalClient(AbstractClient):
Returns:
blocks (List[Block]): The blocks in the agent's core memory
"""
block_ids = self.server.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id)
return [self.server.block_manager.get_block_by_id(block_id, actor=self.user) for block_id in block_ids]
agent = self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
return agent.memory.blocks
def get_agent_memory_block(self, agent_id: str, label: str) -> Block:
"""
@@ -3402,8 +3386,7 @@ class LocalClient(AbstractClient):
Returns:
block (Block): The block corresponding to the label
"""
block_id = self.server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=label)
return self.server.block_manager.get_block_by_id(block_id, actor=self.user)
return self.server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=label, actor=self.user)
def update_agent_memory_block(
self,