feat: added ability to disable the initial message sequence during agent creation (#1978)
This commit is contained in:
@@ -235,6 +235,7 @@ class Agent(BaseAgent):
|
||||
# extras
|
||||
messages_total: Optional[int] = None, # TODO remove?
|
||||
first_message_verify_mono: bool = True, # TODO move to config?
|
||||
initial_message_sequence: Optional[List[Message]] = None,
|
||||
):
|
||||
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
|
||||
# Hold a copy of the state that was used to init the agent
|
||||
@@ -294,6 +295,7 @@ class Agent(BaseAgent):
|
||||
|
||||
else:
|
||||
printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")
|
||||
assert self.agent_state.id is not None and self.agent_state.user_id is not None
|
||||
|
||||
# Generate a sequence of initial messages to put in the buffer
|
||||
init_messages = initialize_message_sequence(
|
||||
@@ -306,14 +308,40 @@ class Agent(BaseAgent):
|
||||
include_initial_boot_message=True,
|
||||
)
|
||||
|
||||
# Cast the messages to actual Message objects to be synced to the DB
|
||||
init_messages_objs = []
|
||||
for msg in init_messages:
|
||||
init_messages_objs.append(
|
||||
if initial_message_sequence is not None:
|
||||
# We always need the system prompt up front
|
||||
system_message_obj = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
model=self.model,
|
||||
openai_message_dict=init_messages[0],
|
||||
)
|
||||
# Don't use anything else in the pregen sequence, instead use the provided sequence
|
||||
init_messages = [system_message_obj] + initial_message_sequence
|
||||
|
||||
else:
|
||||
# Basic "more human than human" initial message sequence
|
||||
init_messages = initialize_message_sequence(
|
||||
model=self.model,
|
||||
system=self.system,
|
||||
memory=self.memory,
|
||||
archival_memory=None,
|
||||
recall_memory=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
)
|
||||
# Cast to Message objects
|
||||
init_messages = [
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=msg
|
||||
)
|
||||
)
|
||||
for msg in init_messages
|
||||
]
|
||||
|
||||
# Cast the messages to actual Message objects to be synced to the DB
|
||||
init_messages_objs = []
|
||||
for msg in init_messages:
|
||||
init_messages_objs.append(msg)
|
||||
assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages)
|
||||
|
||||
# Put the messages inside the message buffer
|
||||
|
||||
@@ -376,6 +376,7 @@ class RESTClient(AbstractClient):
|
||||
# metadata
|
||||
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
|
||||
description: Optional[str] = None,
|
||||
initial_message_sequence: Optional[List[Message]] = None,
|
||||
) -> AgentState:
|
||||
"""Create an agent
|
||||
|
||||
@@ -428,9 +429,18 @@ class RESTClient(AbstractClient):
|
||||
agent_type=agent_type,
|
||||
llm_config=llm_config if llm_config else self._default_llm_config,
|
||||
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
|
||||
initial_message_sequence=initial_message_sequence,
|
||||
)
|
||||
|
||||
# Use model_dump_json() instead of model_dump()
|
||||
# If we use model_dump(), the datetime objects will not be serialized correctly
|
||||
# response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers)
|
||||
response = requests.post(
|
||||
f"{self.base_url}/{self.api_prefix}/agents",
|
||||
data=request.model_dump_json(), # Use model_dump_json() instead of json=model_dump()
|
||||
headers={"Content-Type": "application/json", **self.headers},
|
||||
)
|
||||
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}")
|
||||
return AgentState(**response.json())
|
||||
@@ -1648,6 +1658,7 @@ class LocalClient(AbstractClient):
|
||||
# metadata
|
||||
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
|
||||
description: Optional[str] = None,
|
||||
initial_message_sequence: Optional[List[Message]] = None,
|
||||
) -> AgentState:
|
||||
"""Create an agent
|
||||
|
||||
@@ -1702,6 +1713,7 @@ class LocalClient(AbstractClient):
|
||||
agent_type=agent_type,
|
||||
llm_config=llm_config if llm_config else self._default_llm_config,
|
||||
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
|
||||
initial_message_sequence=initial_message_sequence,
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
@@ -105,7 +104,7 @@ class AgentState(BaseAgent, validate_assignment=True):
|
||||
class CreateAgent(BaseAgent):
|
||||
# all optional as server can generate defaults
|
||||
name: Optional[str] = Field(None, description="The name of the agent.")
|
||||
message_ids: Optional[List[uuid.UUID]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
|
||||
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
|
||||
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
|
||||
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
|
||||
tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.")
|
||||
@@ -113,6 +112,11 @@ class CreateAgent(BaseAgent):
|
||||
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
|
||||
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
|
||||
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
|
||||
# Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence
|
||||
# If the client wants to make this empty, then the client can set the arg to an empty list
|
||||
initial_message_sequence: Optional[List[Message]] = Field(
|
||||
None, description="The initial set of messages to put in the agent's in-context memory."
|
||||
)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
|
||||
@@ -21,6 +21,8 @@ class LettaBase(BaseModel):
|
||||
from_attributes=True,
|
||||
# throw errors if attributes are given that don't belong
|
||||
extra="forbid",
|
||||
# handle datetime serialization consistently across all models
|
||||
# json_encoders={datetime: lambda dt: (dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt).isoformat()},
|
||||
)
|
||||
|
||||
# def __id_prefix__(self):
|
||||
|
||||
@@ -857,7 +857,10 @@ class SyncServer(Server):
|
||||
agent_state=agent_state,
|
||||
tools=tool_objs,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
|
||||
first_message_verify_mono=(
|
||||
True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False
|
||||
),
|
||||
initial_message_sequence=request.initial_message_sequence,
|
||||
)
|
||||
elif request.agent_type == AgentType.o1_agent:
|
||||
agent = O1Agent(
|
||||
@@ -865,7 +868,9 @@ class SyncServer(Server):
|
||||
agent_state=agent_state,
|
||||
tools=tool_objs,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
|
||||
first_message_verify_mono=(
|
||||
True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False
|
||||
),
|
||||
)
|
||||
# rebuilding agent memory on agent create in case shared memory blocks
|
||||
# were specified in the new agent's memory config. we're doing this for two reasons:
|
||||
|
||||
@@ -8,11 +8,12 @@ import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from letta import create_client
|
||||
from letta.agent import initialize_message_sequence
|
||||
from letta.client.client import LocalClient, RESTClient
|
||||
from letta.constants import DEFAULT_PRESET
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.enums import MessageRole, MessageStreamStatus
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
FunctionCallMessage,
|
||||
@@ -28,6 +29,7 @@ from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import model_settings
|
||||
from letta.utils import get_utc_time
|
||||
from tests.helpers.client_helper import upload_file_using_client
|
||||
|
||||
# from tests.utils import create_config
|
||||
@@ -598,3 +600,75 @@ def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState
|
||||
# cleanup
|
||||
client.delete_agent(agent_state1.id)
|
||||
client.delete_agent(agent_state2.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_agents():
|
||||
created_agents = []
|
||||
yield created_agents
|
||||
# Cleanup will run even if test fails
|
||||
for agent_id in created_agents:
|
||||
try:
|
||||
client.delete_agent(agent_id)
|
||||
except Exception as e:
|
||||
print(f"Failed to delete agent {agent_id}: {e}")
|
||||
|
||||
|
||||
def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]):
|
||||
"""Test that we can set an initial message sequence
|
||||
|
||||
If we pass in None, we should get a "default" message sequence
|
||||
If we pass in a non-empty list, we should get that sequence
|
||||
If we pass in an empty list, we should get an empty sequence
|
||||
"""
|
||||
|
||||
# The reference initial message sequence:
|
||||
reference_init_messages = initialize_message_sequence(
|
||||
model=agent.llm_config.model,
|
||||
system=agent.system,
|
||||
memory=agent.memory,
|
||||
archival_memory=None,
|
||||
recall_memory=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
)
|
||||
|
||||
# system, login message, send_message test, send_message receipt
|
||||
assert len(reference_init_messages) > 0
|
||||
assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}"
|
||||
|
||||
# Test with default sequence
|
||||
default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None)
|
||||
cleanup_agents.append(default_agent_state.id)
|
||||
assert default_agent_state.message_ids is not None
|
||||
assert len(default_agent_state.message_ids) > 0
|
||||
assert len(default_agent_state.message_ids) == len(
|
||||
reference_init_messages
|
||||
), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}"
|
||||
|
||||
# Test with empty sequence
|
||||
empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[])
|
||||
cleanup_agents.append(empty_agent_state.id)
|
||||
assert empty_agent_state.message_ids is not None
|
||||
assert len(empty_agent_state.message_ids) == 1, f"Expected 0 messages, got {len(empty_agent_state.message_ids)}"
|
||||
|
||||
# Test with custom sequence
|
||||
custom_sequence = [
|
||||
Message(
|
||||
role=MessageRole.user,
|
||||
text="Hello, how are you?",
|
||||
user_id=agent.user_id,
|
||||
agent_id=agent.id,
|
||||
model=agent.llm_config.model,
|
||||
name=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
),
|
||||
]
|
||||
custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence)
|
||||
cleanup_agents.append(custom_agent_state.id)
|
||||
assert custom_agent_state.message_ids is not None
|
||||
assert (
|
||||
len(custom_agent_state.message_ids) == len(custom_sequence) + 1
|
||||
), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
|
||||
assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]
|
||||
|
||||
Reference in New Issue
Block a user