257 lines
9.5 KiB
Python
257 lines
9.5 KiB
Python
import asyncio
|
|
import threading
|
|
from datetime import datetime
|
|
from typing import List, Optional
|
|
|
|
from letta.agent import Agent, AgentState
|
|
from letta.groups.helpers import stringify_message
|
|
from letta.interface import AgentInterface
|
|
from letta.orm import User
|
|
from letta.schemas.enums import JobStatus
|
|
from letta.schemas.job import JobUpdate
|
|
from letta.schemas.letta_message_content import TextContent
|
|
from letta.schemas.message import Message, MessageCreate
|
|
from letta.schemas.run import Run
|
|
from letta.schemas.usage import LettaUsageStatistics
|
|
from letta.server.rest_api.interface import StreamingServerInterface
|
|
from letta.services.group_manager import GroupManager
|
|
from letta.services.job_manager import JobManager
|
|
from letta.services.message_manager import MessageManager
|
|
|
|
|
|
class BackgroundMultiAgent(Agent):
    """Group agent that steps a main agent and hands its recent messages to background participants.

    On every ``step`` the main agent (built from ``agent_state``) handles the incoming messages.
    Depending on ``background_agents_frequency``, the main agent's response messages are then
    forwarded to each agent in ``agent_ids``. Each hand-off is tracked as a ``Run`` job through
    ``JobManager``, and the created run ids are attached to the returned usage statistics.
    """

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User,
        # custom
        group_id: str = "",
        agent_ids: Optional[List[str]] = None,
        description: str = "",
        background_agents_frequency: Optional[int] = None,
    ):
        """Create the group agent.

        Args:
            interface: Interface used by the main agent.
            agent_state: State of the main agent.
            user: Actor passed to every manager call.
            group_id: Id of the group. It is stamped on created messages and used for the group's
                turn counter and last-processed-message bookkeeping.
            agent_ids: Ids of the background participant agents. ``None`` means no participants.
            description: Free-form group description.
            background_agents_frequency: ``None`` dispatches participants after every turn. A
                positive ``n`` dispatches them when the group's turn counter is a multiple of ``n``.
                ``0`` or a negative value never dispatches them.
        """
        super().__init__(interface, agent_state, user)
        self.group_id = group_id
        # A fresh list per instance; a literal `[]` default would be shared by every instance.
        self.agent_ids = agent_ids if agent_ids is not None else []
        self.description = description
        self.background_agents_frequency = background_agents_frequency
        self.group_manager = GroupManager()
        self.message_manager = MessageManager()
        self.job_manager = JobManager()

    def _run_async_in_new_thread(self, coro):
        """Run ``coro`` on a fresh event loop in a new thread, block until done, and return its result.

        Tasks that ``coro`` spawned and that are still pending when it returns are run to
        completion before the loop is closed, so they are not destroyed mid-flight.

        An exception raised by ``coro`` is re-raised in the calling thread. An exception escaping
        a ``threading.Thread`` target would otherwise only reach the thread excepthook, and the
        caller would silently receive ``None``.
        """
        outcome = {}

        def run_async():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                outcome["result"] = loop.run_until_complete(coro)
            except BaseException as exc:  # handed back to the calling thread after join()
                outcome["error"] = exc
            finally:
                try:
                    pending = asyncio.all_tasks(loop)
                    if pending:
                        # return_exceptions=True: one failing spawned task must not abandon the others.
                        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
                finally:
                    loop.close()
                    asyncio.set_event_loop(None)

        thread = threading.Thread(target=run_async)
        thread.start()
        thread.join()

        if "error" in outcome:
            raise outcome["error"]
        return outcome.get("result")

    async def _issue_background_task(
        self,
        participant_agent_id: str,
        messages: List[Message],
        chaining: bool,
        max_chaining_steps: Optional[int],
        token_streaming: bool,
        metadata: Optional[dict],
        put_inner_thoughts_first: bool,
        last_processed_message_id: str,
    ) -> str:
        """Create a ``Run`` job for one participant and schedule its step as a task.

        The step (``_perform_background_agent_step``) is scheduled on the running event loop via
        ``asyncio.create_task``. The new run id is returned without awaiting the task; the task
        itself updates the job's status when it finishes.
        """
        run = Run(
            user_id=self.user.id,
            status=JobStatus.created,
            metadata={
                "job_type": "background_agent_send_message_async",
                "agent_id": participant_agent_id,
            },
        )
        run = self.job_manager.create_job(pydantic_job=run, actor=self.user)

        asyncio.create_task(
            self._perform_background_agent_step(
                participant_agent_id=participant_agent_id,
                messages=messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                token_streaming=token_streaming,
                metadata=metadata,
                put_inner_thoughts_first=put_inner_thoughts_first,
                last_processed_message_id=last_processed_message_id,
                run_id=run.id,
            )
        )

        return run.id

    async def _perform_background_agent_step(
        self,
        participant_agent_id: str,
        messages: List[Message],
        chaining: bool,
        max_chaining_steps: Optional[int],
        token_streaming: bool,
        metadata: Optional[dict],
        put_inner_thoughts_first: bool,
        last_processed_message_id: str,
        run_id: str,
    ) -> LettaUsageStatistics:
        """Step one participant on a transcript of the main agent's messages and record the outcome on ``run_id``.

        Prior messages are prepended only when ``background_agents_frequency`` is truthy and
        ``messages`` is non-empty. They are the main agent's messages listed after
        ``last_processed_message_id`` and before ``messages[0]``, fetched on a best-effort basis.
        All messages are stringified, and those that stringify to ``None`` are dropped. The rest
        are joined with newlines into a single user message for the participant.

        Returns:
            The participant agent's step result. On success the job is marked completed with that
            result serialized into its metadata.

        Raises:
            Whatever the step raises. The job is first marked failed with the error text.
        """
        try:
            participant_agent_state = self.agent_manager.get_agent_by_id(participant_agent_id, actor=self.user)
            participant_agent = Agent(
                agent_state=participant_agent_state,
                interface=StreamingServerInterface(),
                user=self.user,
            )

            prior_messages = []
            # `messages[0]` is the upper bound of the lookup, so there is nothing to bound without messages.
            if self.background_agents_frequency and messages:
                try:
                    prior_messages = self.message_manager.list_messages_for_agent(
                        agent_id=self.agent_state.id,
                        actor=self.user,
                        after=last_processed_message_id,
                        before=messages[0].id,
                    )
                except Exception as e:
                    print(f"Error fetching prior messages: {str(e)}")
                    # Best effort: continue with just the latest messages.

            summaries = (stringify_message(message) for message in prior_messages + messages)
            message_text = "\n".join(summary for summary in summaries if summary is not None)

            participant_agent_messages = [
                Message(
                    id=Message.generate_id(),
                    agent_id=participant_agent.agent_state.id,
                    role="user",
                    content=[TextContent(text=message_text)],
                    group_id=self.group_id,
                )
            ]
            result = participant_agent.step(
                messages=participant_agent_messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                stream=token_streaming,
                skip_verify=True,
                metadata=metadata,
                put_inner_thoughts_first=put_inner_thoughts_first,
            )
            job_update = JobUpdate(
                status=JobStatus.completed,
                completed_at=datetime.utcnow(),
                metadata={"result": result.model_dump(mode="json")},  # Store the result in metadata
            )
            self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
            return result
        except Exception as e:
            job_update = JobUpdate(
                status=JobStatus.failed,
                completed_at=datetime.utcnow(),
                metadata={"error": str(e)},
            )
            self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
            raise

    def _background_agents_due(self) -> bool:
        """Return whether participants should be dispatched this turn.

        For a positive frequency this bumps the group's turn counter, so it must be called exactly
        once per turn.
        """
        frequency = self.background_agents_frequency
        if frequency is None:
            return True
        if frequency <= 0:
            return False
        turns_counter = self.group_manager.bump_turns_counter(group_id=self.group_id, actor=self.user)
        return turns_counter is not None and turns_counter % frequency == 0

    def _dispatch_background_agents(
        self,
        usage_stats: LettaUsageStatistics,
        chaining: bool,
        max_chaining_steps: Optional[int],
        token_streaming: bool,
        metadata: Optional[dict],
        put_inner_thoughts_first: bool,
    ) -> List[str]:
        """Issue one background run per participant for the main agent's response messages.

        Each run is issued through ``_run_async_in_new_thread``, which joins its worker thread.
        That worker finishes the scheduled participant task before closing its loop, so the
        participants run one after another and this call waits for all of them.

        Returns:
            The created run ids, in ``agent_ids`` order. The list is empty when the main agent
            produced no messages.
        """
        last_response_messages = [message for step_messages in usage_stats.steps_messages for message in step_messages]
        if not last_response_messages:
            # Nothing new to hand off; `last_response_messages[-1]` below would raise IndexError.
            return []

        last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
            group_id=self.group_id, last_processed_message_id=last_response_messages[-1].id, actor=self.user
        )

        run_ids = []
        for participant_agent_id in self.agent_ids:
            try:
                run_id = self._run_async_in_new_thread(
                    self._issue_background_task(
                        participant_agent_id,
                        last_response_messages,
                        chaining,
                        max_chaining_steps,
                        token_streaming,
                        metadata,
                        put_inner_thoughts_first,
                        last_processed_message_id,
                    )
                )
            except Exception as e:
                # A failure to issue one participant's run aborts the whole step.
                print(f"Agent processing failed: {str(e)}")
                raise
            run_ids.append(run_id)
        return run_ids

    def step(
        self,
        messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
        **kwargs,
    ) -> LettaUsageStatistics:
        """Step the main agent on ``messages``, then dispatch background participants if this turn is due.

        Args:
            messages: Incoming messages. They are converted to ``Message`` objects owned by the
                main agent and tagged with ``group_id``.
            chaining: Forwarded to the main and participant agent steps.
            max_chaining_steps: Forwarded to the main and participant agent steps.
            put_inner_thoughts_first: Forwarded to the main and participant agent steps.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The main agent's usage statistics. ``run_ids`` holds the participant runs created this
            turn and is empty when none were dispatched.

        Raises:
            Any exception from the main agent step or from issuing a participant run.
            ``interface.step_yield()`` is still called in that case.
        """
        token_streaming = getattr(self.interface, "streaming_mode", False)
        metadata = getattr(self.interface, "metadata", None)

        input_messages = [
            Message(
                id=Message.generate_id(),
                agent_id=self.agent_state.id,
                role=message.role,
                content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content,
                name=message.name,
                model=None,
                tool_calls=None,
                tool_call_id=None,
                group_id=self.group_id,
                otid=message.otid,
            )
            for message in messages
        ]

        run_ids: List[str] = []
        try:
            main_agent = Agent(
                agent_state=self.agent_state,
                interface=self.interface,
                user=self.user,
            )
            usage_stats = main_agent.step(
                messages=input_messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                stream=token_streaming,
                skip_verify=True,
                metadata=metadata,
                put_inner_thoughts_first=put_inner_thoughts_first,
            )

            if self._background_agents_due():
                run_ids = self._dispatch_background_agents(
                    usage_stats,
                    chaining=chaining,
                    max_chaining_steps=max_chaining_steps,
                    token_streaming=token_streaming,
                    metadata=metadata,
                    put_inner_thoughts_first=put_inner_thoughts_first,
                )
        finally:
            self.interface.step_yield()

        self.interface.step_complete()

        usage_stats.run_ids = run_ids
        return LettaUsageStatistics(**usage_stats.model_dump())