From 47ee4effd792b61faeb913f25cce7d042fa30a09 Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Tue, 15 Oct 2024 13:32:13 -0700
Subject: [PATCH] refactor: make `Agent.step()` multi-step (#1884)

---
 letta/agent.py                             | 118 +++++++++++++++---
 letta/client/client.py                     |  23 ++--
 letta/main.py                              |  10 +-
 letta/schemas/agent.py                     |   5 +-
 .../routers/openai/assistants/threads.py   |   2 +-
 letta/server/server.py                     |  92 ++------------
 tests/test_new_client.py                   |  23 ++--
 7 files changed, 148 insertions(+), 125 deletions(-)

diff --git a/letta/agent.py b/letta/agent.py
index c9b55dbe..76ba7378 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -11,13 +11,17 @@ from letta.agent_store.storage import StorageConnector
 from letta.constants import (
     CLI_WARNING_PREFIX,
     FIRST_MESSAGE_ATTEMPTS,
+    FUNC_FAILED_HEARTBEAT_MESSAGE,
     IN_CONTEXT_MEMORY_KEYWORD,
     LLM_MAX_TOKENS,
     MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
     MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
     MESSAGE_SUMMARY_WARNING_FRAC,
+    REQ_HEARTBEAT_MESSAGE,
 )
+from letta.errors import LLMError
 from letta.interface import AgentInterface
+from letta.llm_api.helpers import is_context_overflow_error
 from letta.llm_api.llm_api_tools import create
 from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
 from letta.metadata import MetadataStore
@@ -32,11 +36,15 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.schemas.openai.chat_completion_response import (
     Message as ChatCompletionMessage,
 )
+from letta.schemas.openai.chat_completion_response import UsageStatistics
 from letta.schemas.passage import Passage
 from letta.schemas.tool import Tool
+from letta.schemas.usage import LettaUsageStatistics
 from letta.system import (
+    get_heartbeat,
     get_initial_boot_messages,
     get_login_event,
+    get_token_limit_warning,
     package_function_response,
     package_summarize_message,
     package_user_message,
@@ -56,9 +64,6 @@ from letta.utils import (
     verify_first_message_correctness,
 )
 
-from .errors import LLMError
-from .llm_api.helpers import is_context_overflow_error
-
 
 def compile_memory_metadata_block(
     memory_edit_timestamp: datetime.datetime,
@@ -202,7 +207,7 @@ class BaseAgent(ABC):
     def step(
         self,
         messages: Union[Message, List[Message]],
-    ) -> AgentStepResponse:
+    ) -> LettaUsageStatistics:
         """
         Top-level event message handler for the agent.
""" @@ -721,18 +726,105 @@ class Agent(BaseAgent): return messages, heartbeat_request, function_failed def step( + self, + messages: Union[Message, List[Message]], + # additional args + chaining: bool = True, + max_chaining_steps: Optional[int] = None, + ms: Optional[MetadataStore] = None, + **kwargs, + ) -> LettaUsageStatistics: + """Run Agent.step in a loop, handling chaining via heartbeat requests and function failures""" + # assert ms is not None, "MetadataStore is required" + + next_input_message = messages if isinstance(messages, list) else [messages] + counter = 0 + total_usage = UsageStatistics() + step_count = 0 + while True: + kwargs["ms"] = ms + kwargs["first_message"] = False + step_response = self.inner_step( + messages=next_input_message, + **kwargs, + ) + step_response.messages + heartbeat_request = step_response.heartbeat_request + function_failed = step_response.function_failed + token_warning = step_response.in_context_memory_warning + usage = step_response.usage + + step_count += 1 + total_usage += usage + counter += 1 + self.interface.step_complete() + + # logger.debug("Saving agent state") + # save updated state + if ms: + save_agent(self, ms) + + # Chain stops + if not chaining: + printd("No chaining, stopping after one step") + break + elif max_chaining_steps is not None and counter > max_chaining_steps: + printd(f"Hit max chaining steps, stopping after {counter} steps") + break + # Chain handlers + elif token_warning: + assert self.agent_state.user_id is not None + next_input_message = Message.dict_to_message( + agent_id=self.agent_state.id, + user_id=self.agent_state.user_id, + model=self.model, + openai_message_dict={ + "role": "user", # TODO: change to system? + "content": get_token_limit_warning(), + }, + ) + continue # always chain + elif function_failed: + assert self.agent_state.user_id is not None + next_input_message = Message.dict_to_message( + agent_id=self.agent_state.id, + user_id=self.agent_state.user_id, + model=self.model, + openai_message_dict={ + "role": "user", # TODO: change to system? + "content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE), + }, + ) + continue # always chain + elif heartbeat_request: + assert self.agent_state.user_id is not None + next_input_message = Message.dict_to_message( + agent_id=self.agent_state.id, + user_id=self.agent_state.user_id, + model=self.model, + openai_message_dict={ + "role": "user", # TODO: change to system? + "content": get_heartbeat(REQ_HEARTBEAT_MESSAGE), + }, + ) + continue # always chain + # Letta no-op / yield + else: + break + + return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count) + + def inner_step( self, messages: Union[Message, List[Message]], first_message: bool = False, first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, skip_verify: bool = False, - return_dicts: bool = True, - # recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field stream: bool = False, # TODO move to config? 
         inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
         ms: Optional[MetadataStore] = None,
     ) -> AgentStepResponse:
-        """Top-level event message handler for the Letta agent"""
+        """Runs a single step in the agent loop (generates at most one LLM call)"""
 
         try:
@@ -834,13 +926,12 @@ class Agent(BaseAgent):
             )
 
             self._append_to_messages(all_new_messages)
-            messages_to_return = [msg.to_openai_dict() for msg in all_new_messages] if return_dicts else all_new_messages
 
             # update state after each step
             self.update_state()
 
             return AgentStepResponse(
-                messages=messages_to_return,
+                messages=all_new_messages,
                 heartbeat_request=heartbeat_request,
                 function_failed=function_failed,
                 in_context_memory_warning=active_memory_warning,
@@ -856,15 +947,12 @@ class Agent(BaseAgent):
                 self.summarize_messages_inplace()
 
                 # Try step again
-                return self.step(
+                return self.inner_step(
                     messages=messages,
                     first_message=first_message,
                     first_message_retry_limit=first_message_retry_limit,
                     skip_verify=skip_verify,
-                    return_dicts=return_dicts,
-                    # recreate_message_timestamp=recreate_message_timestamp,
                     stream=stream,
-                    # timestamp=timestamp,
                     inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
                     ms=ms,
                 )
@@ -905,7 +993,7 @@ class Agent(BaseAgent):
             # created_at=timestamp,
         )
 
-        return self.step(messages=[user_message], **kwargs)
+        return self.inner_step(messages=[user_message], **kwargs)
 
     def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
         assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
@@ -1326,7 +1414,7 @@ class Agent(BaseAgent):
         self.pop_until_user()
         user_message = self.pop_message(count=1)[0]
         assert user_message.text is not None, "User message text is None"
-        step_response = self.step_user_message(user_message_str=user_message.text, return_dicts=False)
+        step_response = self.step_user_message(user_message_str=user_message.text)
         messages = step_response.messages
         assert messages is not None
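
Note: the chaining loop that used to live in `SyncServer._step` (removed below in letta/server/server.py) is now the body of `Agent.step()`, with `inner_step()` exposing exactly one LLM call. A minimal, self-contained sketch of the contract it implements follows; every name here is an illustrative stand-in for the real code above (`inner_step`, `get_heartbeat`, `get_token_limit_warning`), not the Letta API itself:

    # Sketch of the chaining contract implemented by Agent.step() above.
    # All names are hypothetical stand-ins; usage is a plain int here, while
    # the real code aggregates UsageStatistics objects.
    from dataclasses import dataclass
    from typing import Callable, List, Optional, Tuple

    @dataclass
    class FakeStepResponse:
        heartbeat_request: bool = False
        function_failed: bool = False
        in_context_memory_warning: bool = False
        usage: int = 0

    def run_chain(
        inner_step: Callable[[List[str]], FakeStepResponse],
        messages: List[str],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
    ) -> Tuple[int, int]:
        """Step until the agent yields; return (total_usage, step_count)."""
        next_input, total_usage, steps = messages, 0, 0
        while True:
            resp = inner_step(next_input)
            steps += 1
            total_usage += resp.usage
            if not chaining:
                break  # single-step mode
            if max_chaining_steps is not None and steps > max_chaining_steps:
                break  # hard cap on automatic chaining
            if resp.in_context_memory_warning:
                next_input = ["[token limit warning]"]         # ~ get_token_limit_warning()
            elif resp.function_failed:
                next_input = ["[heartbeat: function failed]"]  # ~ get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE)
            elif resp.heartbeat_request:
                next_input = ["[heartbeat: continue]"]         # ~ get_heartbeat(REQ_HEARTBEAT_MESSAGE)
            else:
                break  # agent yielded to the user; stop chaining
        return total_usage, steps

    # e.g. one failing step followed by a clean one chains exactly twice:
    responses = iter([FakeStepResponse(function_failed=True, usage=10), FakeStepResponse(usage=5)])
    print(run_chain(lambda msgs: next(responses), ["hello"]))  # -> (15, 2)
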
diff --git a/letta/client/client.py b/letta/client/client.py
index 35a9a184..40f2275d 100644
--- a/letta/client/client.py
+++ b/letta/client/client.py
@@ -747,8 +747,9 @@ class RESTClient(AbstractClient):
         # simplify messages
         if not include_full_message:
             messages = []
-            for message in response.messages:
-                messages += message.to_letta_message()
+            for m in response.messages:
+                assert isinstance(m, Message)
+                messages += m.to_letta_message()
             response.messages = messages
 
         return response
@@ -1677,7 +1678,7 @@ class LocalClient(AbstractClient):
         self.interface.clear()
         return self.server.get_agent_state(user_id=self.user_id, agent_id=agent_id)
 
-    def get_agent_id(self, agent_name: str) -> AgentState:
+    def get_agent_id(self, agent_name: str) -> Optional[str]:
         """
         Get the ID of an agent by name (names are unique per user)
 
@@ -1767,6 +1768,7 @@ class LocalClient(AbstractClient):
         self,
         message: str,
         role: str,
+        name: Optional[str] = None,
         agent_id: Optional[str] = None,
         agent_name: Optional[str] = None,
         stream_steps: bool = False,
@@ -1790,19 +1792,18 @@ class LocalClient(AbstractClient):
             # lookup agent by name
             assert agent_name, f"Either agent_id or agent_name must be provided"
             agent_id = self.get_agent_id(agent_name=agent_name)
-
-            agent_state = self.get_agent(agent_id=agent_id)
+            assert agent_id, f"Agent with name {agent_name} not found"
 
         if stream_steps or stream_tokens:
             # TODO: implement streaming with stream=True/False
             raise NotImplementedError
 
         self.interface.clear()
-        if role == "system":
-            usage = self.server.system_message(user_id=self.user_id, agent_id=agent_id, message=message)
-        elif role == "user":
-            usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message)
-        else:
-            raise ValueError(f"Role {role} not supported")
+
+        usage = self.server.send_messages(
+            user_id=self.user_id,
+            agent_id=agent_id,
+            messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
+        )
 
         # auto-save
         if self.auto_save:
diff --git a/letta/main.py b/letta/main.py
index b084333c..f16eb895 100644
--- a/letta/main.py
+++ b/letta/main.py
@@ -361,8 +361,10 @@ def run_agent_loop(
     skip_next_user_input = False
 
     def process_agent_step(user_message, no_verify):
+        # TODO(charles): update to use agent.step() instead of inner_step()
+
         if user_message is None:
-            step_response = letta_agent.step(
+            step_response = letta_agent.inner_step(
                 messages=[],
                 first_message=False,
                 skip_verify=no_verify,
@@ -402,15 +404,15 @@ def run_agent_loop(
         while True:
             try:
                 if strip_ui:
-                    new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
+                    _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
                     break
                 else:
                     if stream:
                         # Don't display the "Thinking..." if streaming
-                        new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
+                        _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
                     else:
                         with console.status("[bold cyan]Thinking...") as status:
-                            new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
+                            _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
                     break
             except KeyboardInterrupt:
                 print("User interrupt occurred.")
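
Note: `LocalClient.send_message` above no longer branches on role strings; every role is funneled through `server.send_messages` as a typed `MessageCreate`. Roughly, using the schema names from this patch (validation behavior is the pydantic/enum default, and this helper is a sketch, not code from the repo):

    # Sketch: the MessageRole enum replaces the old if/elif role dispatch.
    # Imports follow the module paths used in this patch; run inside a letta checkout.
    from typing import Optional

    from letta.schemas.message import MessageCreate, MessageRole

    def to_message_create(message: str, role: str, name: Optional[str] = None) -> MessageCreate:
        # MessageRole("user") and MessageRole("system") succeed; an unknown role
        # raises ValueError, replacing the old explicit
        # `raise ValueError(f"Role {role} not supported")` branch.
        return MessageCreate(role=MessageRole(role), text=message, name=name)

    print(to_message_create("Hello", "user"))
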
diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py
index 8c40d31f..367765fd 100644
--- a/letta/schemas/agent.py
+++ b/letta/schemas/agent.py
@@ -1,7 +1,7 @@
 import uuid
 from datetime import datetime
 from enum import Enum
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional
 
 from pydantic import BaseModel, Field, field_validator
 
@@ -121,8 +121,7 @@ class UpdateAgentState(BaseAgent):
 
 
 class AgentStepResponse(BaseModel):
-    # TODO remove support for list of dicts
-    messages: Union[List[Message], List[dict]] = Field(..., description="The messages generated during the agent's step.")
+    messages: List[Message] = Field(..., description="The messages generated during the agent's step.")
     heartbeat_request: bool = Field(..., description="Whether the agent requested a heartbeat (i.e. follow-up execution).")
     function_failed: bool = Field(..., description="Whether the agent step ended because a function call failed.")
     in_context_memory_warning: bool = Field(
diff --git a/letta/server/rest_api/routers/openai/assistants/threads.py b/letta/server/rest_api/routers/openai/assistants/threads.py
index 43d7235f..af63e7b7 100644
--- a/letta/server/rest_api/routers/openai/assistants/threads.py
+++ b/letta/server/rest_api/routers/openai/assistants/threads.py
@@ -248,7 +248,7 @@ def create_run(
     agent_id = thread_id
     # TODO: override preset of agent with request.assistant_id
     agent = server._get_or_load_agent(agent_id=agent_id)
-    agent.step(user_message=None)  # already has messages added
+    agent.inner_step(messages=[])  # already has messages added
     run_id = str(uuid.uuid4())
     create_time = int(get_utc_time().timestamp())
     return OpenAIRun(
diff --git a/letta/server/server.py b/letta/server/server.py
index 748a14c1..16df2be9 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -74,7 +74,6 @@ from letta.schemas.letta_message import LettaMessage
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
 from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
-from letta.schemas.openai.chat_completion_response import UsageStatistics
 from letta.schemas.organization import Organization, OrganizationCreate
 from letta.schemas.passage import Passage
 from letta.schemas.source import Source, SourceCreate, SourceUpdate
@@ -411,6 +410,7 @@ class SyncServer(Server):
             raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}")
         logger.debug(f"Got input messages: {input_messages}")
 
+        letta_agent = None
         try:
 
             # Get the agent object (loaded in memory)
@@ -422,83 +422,14 @@ class SyncServer(Server):
             token_streaming = letta_agent.interface.streaming_mode if hasattr(letta_agent.interface, "streaming_mode") else False
 
             logger.debug(f"Starting agent step")
-            no_verify = True
-            next_input_message = input_messages
-            counter = 0
-            total_usage = UsageStatistics()
-            step_count = 0
-            while True:
-                step_response = letta_agent.step(
-                    messages=next_input_message,
-                    first_message=False,
-                    skip_verify=no_verify,
-                    return_dicts=False,
-                    stream=token_streaming,
-                    # timestamp=timestamp,
-                    ms=self.ms,
-                )
-                step_response.messages
-                heartbeat_request = step_response.heartbeat_request
-                function_failed = step_response.function_failed
-                token_warning = step_response.in_context_memory_warning
-                usage = step_response.usage
-
-                step_count += 1
-                total_usage += usage
-                counter += 1
-                letta_agent.interface.step_complete()
-
-                logger.debug("Saving agent state")
-                # save updated state
-                save_agent(letta_agent, self.ms)
-
-                # Chain stops
-                if not self.chaining:
-                    logger.debug("No chaining, stopping after one step")
-                    break
-                elif self.max_chaining_steps is not None and counter > self.max_chaining_steps:
-                    logger.debug(f"Hit max chaining steps, stopping after {counter} steps")
-                    break
-                # Chain handlers
-                elif token_warning:
-                    assert letta_agent.agent_state.user_id is not None
-                    next_input_message = Message.dict_to_message(
-                        agent_id=letta_agent.agent_state.id,
-                        user_id=letta_agent.agent_state.user_id,
-                        model=letta_agent.model,
-                        openai_message_dict={
-                            "role": "user",  # TODO: change to system?
- "content": system.get_token_limit_warning(), - }, - ) - continue # always chain - elif function_failed: - assert letta_agent.agent_state.user_id is not None - next_input_message = Message.dict_to_message( - agent_id=letta_agent.agent_state.id, - user_id=letta_agent.agent_state.user_id, - model=letta_agent.model, - openai_message_dict={ - "role": "user", # TODO: change to system? - "content": system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE), - }, - ) - continue # always chain - elif heartbeat_request: - assert letta_agent.agent_state.user_id is not None - next_input_message = Message.dict_to_message( - agent_id=letta_agent.agent_state.id, - user_id=letta_agent.agent_state.user_id, - model=letta_agent.model, - openai_message_dict={ - "role": "user", # TODO: change to system? - "content": system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE), - }, - ) - continue # always chain - # Letta no-op / yield - else: - break + usage_stats = letta_agent.step( + messages=input_messages, + chaining=self.chaining, + max_chaining_steps=self.max_chaining_steps, + stream=token_streaming, + ms=self.ms, + skip_verify=True, + ) except Exception as e: logger.error(f"Error in server._step: {e}") @@ -506,9 +437,10 @@ class SyncServer(Server): raise finally: logger.debug("Calling step_yield()") - letta_agent.interface.step_yield() + if letta_agent: + letta_agent.interface.step_yield() - return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count) + return usage_stats def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics: """Process a CLI command""" diff --git a/tests/test_new_client.py b/tests/test_new_client.py index 395b6020..f4bb1ed6 100644 --- a/tests/test_new_client.py +++ b/tests/test_new_client.py @@ -4,6 +4,7 @@ import pytest from letta import create_client from letta.client.client import LocalClient, RESTClient +from letta.schemas.agent import AgentState from letta.schemas.block import Block from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig @@ -113,7 +114,7 @@ def test_agent(client: Union[LocalClient, RESTClient]): client.delete_agent(agent_state_test.id) -def test_agent_with_shared_blocks(client): +def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]): persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id) human_block = Block(name="human", value="Me Human, I swear. 
Beep boop.", label="human", user_id=client.user_id) existing_non_template_blocks = [persona_block, human_block] @@ -164,7 +165,7 @@ def test_agent_with_shared_blocks(client): client.delete_agent(second_agent_state_test.id) -def test_memory(client, agent): +def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState): # get agent memory original_memory = client.get_in_context_memory(agent.id) assert original_memory is not None @@ -177,7 +178,7 @@ def test_memory(client, agent): assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated -def test_archival_memory(client, agent): +def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState): """Test functions for interacting with archival memory store""" # add archival memory @@ -192,12 +193,12 @@ def test_archival_memory(client, agent): client.delete_archival_memory(agent.id, passage.id) -def test_recall_memory(client, agent): +def test_recall_memory(client: Union[LocalClient, RESTClient], agent: AgentState): """Test functions for interacting with recall memory store""" # send message to the agent message_str = "Hello" - client.send_message(message_str, "user", agent.id) + client.send_message(message=message_str, role="user", agent_id=agent.id) # list messages messages = client.get_messages(agent.id) @@ -216,7 +217,7 @@ def test_recall_memory(client, agent): assert exists -def test_tools(client): +def test_tools(client: Union[LocalClient, RESTClient]): def print_tool(message: str): """ A tool to print a message @@ -265,7 +266,7 @@ def test_tools(client): # assert len(client.list_tools()) == orig_tool_length -def test_tools_from_composio_basic(client): +def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]): from composio_langchain import Action from letta.schemas.tool import Tool @@ -286,7 +287,7 @@ def test_tools_from_composio_basic(client): # The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable -def test_tools_from_crewai(client): +def test_tools_from_crewai(client: Union[LocalClient, RESTClient]): # create crewAI tool from crewai_tools import ScrapeWebsiteTool @@ -323,7 +324,7 @@ def test_tools_from_crewai(client): assert expected_content in func(website_url=simple_webpage_url) -def test_tools_from_crewai_with_params(client): +def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]): # create crewAI tool from crewai_tools import ScrapeWebsiteTool @@ -357,7 +358,7 @@ def test_tools_from_crewai_with_params(client): assert expected_content in func() -def test_tools_from_langchain(client): +def test_tools_from_langchain(client: Union[LocalClient, RESTClient]): # create langchain tool from langchain_community.tools import WikipediaQueryRun from langchain_community.utilities import WikipediaAPIWrapper @@ -391,7 +392,7 @@ def test_tools_from_langchain(client): assert expected_content in func(query="Albert Einstein") -def test_tool_creation_langchain_missing_imports(client): +def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, RESTClient]): # create langchain tool from langchain_community.tools import WikipediaQueryRun from langchain_community.utilities import WikipediaAPIWrapper