chore: Clean up .load_agent usage (#2298)

This commit is contained in:
Matthew Zhou
2024-12-20 16:56:53 -08:00
committed by GitHub
parent a5b1aac1fd
commit 9ad5fd64cf
10 changed files with 134 additions and 164 deletions

View File

@@ -59,7 +59,7 @@ from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUp
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.schemas.tool import Tool, ToolCreate
from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
@@ -303,11 +303,6 @@ class SyncServer(Server):
self.block_manager.add_default_blocks(actor=self.default_user)
self.tool_manager.upsert_base_tools(actor=self.default_user)
# If there is a default org/user
# This logic may have to change in the future
if settings.load_default_external_tools:
self.add_default_external_tools(actor=self.default_user)
# collect providers (always has Letta as a default)
self._enabled_providers: List[Provider] = [LettaProvider()]
if model_settings.openai_api_key:
@@ -431,9 +426,6 @@ class SyncServer(Server):
skip_verify=True,
)
# save agent after step
# save_agent(letta_agent)
except Exception as e:
logger.error(f"Error in server._step: {e}")
print(traceback.print_exc())
@@ -944,11 +936,10 @@ class SyncServer(Server):
agent_states = self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent_state in agent_states:
agent_id = agent_state.id
agent = self.load_agent(agent_id=agent_id, actor=actor)
# Attach source to agent
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
self.agent_manager.attach_source(agent_id=agent_state.id, source_id=source_id, actor=actor)
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
assert new_passage_size >= curr_passage_size # in case empty files are added
@@ -973,56 +964,6 @@ class SyncServer(Server):
passage_count, document_count = load_data(connector, source, self.passage_manager, self.source_manager, actor=user)
return passage_count, document_count
def attach_source_to_agent(
self,
user_id: str,
agent_id: str,
source_id: Optional[str] = None,
source_name: Optional[str] = None,
) -> Source:
# attach a data source to an agent
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
if source_id:
data_source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
elif source_name:
data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
else:
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
assert data_source, f"Data source with id={source_id} or name={source_name} does not exist"
# load agent
agent = self.load_agent(agent_id=agent_id, actor=actor)
# attach source to agent
agent.attach_source(user=actor, source_id=data_source.id, source_manager=self.source_manager, agent_manager=self.agent_manager)
return data_source
def detach_source_from_agent(
self,
user_id: str,
agent_id: str,
source_id: Optional[str] = None,
source_name: Optional[str] = None,
) -> Source:
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
if source_id:
source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
elif source_name:
source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
source_id = source.id
else:
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
# delete agent-source mapping
self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
# return back source data
return source
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
return []
@@ -1060,22 +1001,6 @@ class SyncServer(Server):
return sources_with_metadata
def add_default_external_tools(self, actor: User) -> bool:
"""Add default langchain tools. Return true if successful, false otherwise."""
success = True
tool_creates = ToolCreate.load_default_langchain_tools()
if tool_settings.composio_api_key:
tool_creates += ToolCreate.load_default_composio_tools()
for tool_create in tool_creates:
try:
self.tool_manager.create_or_update_tool(Tool(**tool_create.model_dump()), actor=actor)
except Exception as e:
warnings.warn(f"An error occurred while creating tool {tool_create}: {e}")
warnings.warn(traceback.format_exc())
success = False
return success
def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message:
"""Update the details of a message associated with an agent"""