chore: Deprecate O1 Agent (#573)
This commit is contained in:
1
.github/workflows/integration_tests.yml
vendored
1
.github/workflows/integration_tests.yml
vendored
@@ -29,7 +29,6 @@ jobs:
|
||||
- "integration_test_tool_execution_sandbox.py"
|
||||
- "integration_test_offline_memory_agent.py"
|
||||
- "integration_test_agent_tool_graph.py"
|
||||
- "integration_test_o1_agent.py"
|
||||
services:
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
|
||||
@@ -16,7 +16,6 @@ from letta.constants import (
|
||||
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
|
||||
MESSAGE_SUMMARY_WARNING_FRAC,
|
||||
O1_BASE_TOOLS,
|
||||
REQ_HEARTBEAT_MESSAGE,
|
||||
)
|
||||
from letta.errors import ContextWindowExceededError
|
||||
@@ -212,7 +211,7 @@ class Agent(BaseAgent):
|
||||
# TODO: This is NO BUENO
|
||||
# TODO: Matching purely by names is extremely problematic, users can create tools with these names and run them in the agent loop
|
||||
# TODO: We will have probably have to match the function strings exactly for safety
|
||||
if function_name in BASE_TOOLS or function_name in O1_BASE_TOOLS:
|
||||
if function_name in BASE_TOOLS:
|
||||
# base tools are allowed to access the `Agent` object and run on the database
|
||||
function_args["self"] = self # need to attach self to arg since it's dynamically linked
|
||||
function_response = callable_func(**function_args)
|
||||
|
||||
@@ -42,7 +42,6 @@ DEFAULT_PRESET = "memgpt_chat"
|
||||
# Base tools that cannot be edited, as they access agent state directly
# Note that we don't include "conversation_search_date" for now
BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"]
# Tool names reserved for the O1 (thinking) agent loop; matched by name in the
# agent's function-dispatch path, so they get the same `self`-injection as BASE_TOOLS.
O1_BASE_TOOLS = ["send_thinking_message", "send_final_message"]
# Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
|
||||
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.interface import AgentInterface
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
|
||||
|
||||
def send_thinking_message(self: "Agent", message: str) -> Optional[str]:
    """
    Emit an intermediate "thinking" message so the model can reason out
    loud before committing to a final answer.

    Args:
        message (str): Message contents. All unicode (including emojis) are supported.

    Returns:
        Optional[str]: None is always returned as this function does not produce a response.
    """
    # Reasoning text is surfaced through the agent's interface as an
    # internal monologue; nothing is returned to the tool caller.
    self.interface.internal_monologue(message)
    return None
|
||||
|
||||
|
||||
def send_final_message(self: "Agent", message: str) -> Optional[str]:
    """
    Deliver the final message to the human user once thinking is complete.

    Args:
        message (str): Message contents. All unicode (including emojis) are supported.

    Returns:
        Optional[str]: None is always returned as this function does not produce a response.
    """
    # NOTE(review): this routes through internal_monologue, identical to
    # send_thinking_message — confirm the final message is meant to surface
    # on the monologue channel rather than as an assistant message.
    self.interface.internal_monologue(message)
    return None
|
||||
|
||||
|
||||
class O1Agent(Agent):
    """Agent variant that emulates o1-style reasoning: it repeatedly runs the
    base Agent's inner step loop, letting the model emit "thinking" messages,
    until a ``send_final_message`` tool call is observed or a fixed step
    budget is exhausted."""

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User,
        max_thinking_steps: int = 10,
        first_message_verify_mono: bool = False,
    ):
        super().__init__(interface, agent_state, user)
        # Upper bound on inner_step iterations performed per outer step() call.
        self.max_thinking_steps = max_thinking_steps
        # NOTE(review): stored but never read within this class — presumably
        # consumed by base-class machinery; confirm before relying on it.
        self.first_message_verify_mono = first_message_verify_mono

    def step(
        self,
        messages: Union[Message, List[Message]],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        **kwargs,
    ) -> LettaUsageStatistics:
        """Run Agent.inner_step in a loop, terminate when final thinking message is sent or max_thinking_steps is reached"""
        # NOTE(review): `chaining` and `max_chaining_steps` are accepted for
        # signature compatibility with Agent.step but are unused here.
        # assert ms is not None, "MetadataStore is required"
        # Normalize to a list so the first iteration feeds the user input in.
        next_input_message = messages if isinstance(messages, list) else [messages]

        counter = 0
        total_usage = UsageStatistics()
        step_count = 0
        while step_count < self.max_thinking_steps:
            # After the first iteration there is no fresh user input; the
            # model continues "thinking" on its own.
            if counter > 0:
                next_input_message = []

            kwargs["first_message"] = False
            step_response = self.inner_step(
                messages=next_input_message,
                **kwargs,
            )
            usage = step_response.usage
            step_count += 1
            total_usage += usage
            counter += 1
            self.interface.step_complete()
            # check if it is final thinking message
            if step_response.messages[-1].name == "send_final_message":
                break
            # NOTE(review): the agent is persisted each non-final iteration;
            # the final-message iteration breaks before saving — confirm
            # this skip is intended.
            save_agent(self)

        # Aggregate token usage across all thinking iterations.
        return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||
@@ -25,7 +25,6 @@ class AgentType(str, Enum):
|
||||
|
||||
memgpt_agent = "memgpt_agent"
|
||||
split_thread_agent = "split_thread_agent"
|
||||
o1_agent = "o1_agent"
|
||||
offline_memory_agent = "offline_memory_agent"
|
||||
chat_only_agent = "chat_only_agent"
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.interface import AgentInterface # abstract
|
||||
from letta.interface import CLIInterface # for printing to terminal
|
||||
from letta.log import get_logger
|
||||
from letta.o1_agent import O1Agent
|
||||
from letta.offline_memory_agent import OfflineMemoryAgent
|
||||
from letta.orm import Base
|
||||
from letta.orm.errors import NoResultFound
|
||||
@@ -390,8 +389,6 @@ class SyncServer(Server):
|
||||
interface = interface or self.default_interface_factory()
|
||||
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||
agent = Agent(agent_state=agent_state, interface=interface, user=actor)
|
||||
elif agent_state.agent_type == AgentType.o1_agent:
|
||||
agent = O1Agent(agent_state=agent_state, interface=interface, user=actor)
|
||||
elif agent_state.agent_type == AgentType.offline_memory_agent:
|
||||
agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor)
|
||||
elif agent_state.agent_type == AgentType.chat_only_agent:
|
||||
|
||||
@@ -89,8 +89,6 @@ def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
|
||||
# TODO: don't hardcode
|
||||
if agent_type == AgentType.memgpt_agent:
|
||||
system = gpt_system.get_system_text("memgpt_chat")
|
||||
elif agent_type == AgentType.o1_agent:
|
||||
system = gpt_system.get_system_text("memgpt_modified_o1")
|
||||
elif agent_type == AgentType.offline_memory_agent:
|
||||
system = gpt_system.get_system_text("memgpt_offline_memory")
|
||||
elif agent_type == AgentType.chat_only_agent:
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
from letta.client.client import create_client
|
||||
from letta.constants import DEFAULT_HUMAN
|
||||
from letta.o1_agent import send_final_message, send_thinking_message
|
||||
from letta.schemas.agent import AgentType
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.utils import get_human_text, get_persona_text
|
||||
|
||||
|
||||
def test_o1_agent():
    """End-to-end smoke test for the O1 agent: registers its two thinking
    tools, creates an o1-type agent, and checks that a question produces a
    multi-step thinking exchange."""
    client = create_client()
    assert client is not None

    # Register the O1-specific tools the agent loop dispatches by name.
    thinking_tool = client.create_or_update_tool(send_thinking_message)
    final_tool = client.create_or_update_tool(send_final_message)

    agent_state = client.create_agent(
        agent_type=AgentType.o1_agent,
        tool_ids=[thinking_tool.id, final_tool.id],
        llm_config=LLMConfig.default_config("gpt-4"),
        embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
        memory=ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text("o1_persona")),
    )
    assert client.get_agent(agent_id=agent_state.id) is not None

    response = client.user_message(agent_id=agent_state.id, message="9.9 or 9.11, which is a larger number?")
    assert response is not None
    # More than a single round trip means the agent actually "thought".
    assert len(response.messages) > 3
||||
|
||||
# Allow running this integration test directly, without a pytest runner.
if __name__ == "__main__":
    test_o1_agent()
|
||||
Reference in New Issue
Block a user