feat: Async agent loop (#1387)
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, List
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
|
||||
import openai
|
||||
|
||||
from letta.schemas.letta_message import UserMessage
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
@@ -19,7 +19,8 @@ class BaseAgent(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
openai_client: openai.AsyncClient,
|
||||
# TODO: Make required once client refactor hits
|
||||
openai_client: Optional[openai.AsyncClient],
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
actor: User,
|
||||
@@ -31,14 +32,14 @@ class BaseAgent(ABC):
|
||||
self.actor = actor
|
||||
|
||||
@abstractmethod
|
||||
async def step(self, input_message: UserMessage) -> List[Message]:
|
||||
async def step(self, input_message: UserMessage, max_steps: int = 10) -> LettaResponse:
|
||||
"""
|
||||
Main execution loop for the agent.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def step_stream(self, input_message: UserMessage) -> AsyncGenerator[str, None]:
|
||||
async def step_stream(self, input_message: UserMessage, max_steps: int = 10) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Main async execution loop for the agent. Implementations must yield messages as SSE events.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user