feat: added ability to disable the initial message sequence during agent creation (#1978)

This commit is contained in:
Charles Packer
2024-11-04 16:03:52 -08:00
committed by GitHub
parent 39999ce48c
commit b9f772f196
6 changed files with 136 additions and 11 deletions

View File

@@ -235,6 +235,7 @@ class Agent(BaseAgent):
# extras
messages_total: Optional[int] = None, # TODO remove?
first_message_verify_mono: bool = True, # TODO move to config?
initial_message_sequence: Optional[List[Message]] = None,
):
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
# Hold a copy of the state that was used to init the agent
@@ -294,6 +295,7 @@ class Agent(BaseAgent):
else:
printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")
assert self.agent_state.id is not None and self.agent_state.user_id is not None
# Generate a sequence of initial messages to put in the buffer
init_messages = initialize_message_sequence(
@@ -306,14 +308,40 @@ class Agent(BaseAgent):
include_initial_boot_message=True,
)
# Cast the messages to actual Message objects to be synced to the DB
init_messages_objs = []
for msg in init_messages:
init_messages_objs.append(
if initial_message_sequence is not None:
# We always need the system prompt up front
system_message_obj = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict=init_messages[0],
)
# Don't use anything else in the pregen sequence, instead use the provided sequence
init_messages = [system_message_obj] + initial_message_sequence
else:
# Basic "more human than human" initial message sequence
init_messages = initialize_message_sequence(
model=self.model,
system=self.system,
memory=self.memory,
archival_memory=None,
recall_memory=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
)
# Cast to Message objects
init_messages = [
Message.dict_to_message(
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=msg
)
)
for msg in init_messages
]
# Cast the messages to actual Message objects to be synced to the DB
init_messages_objs = []
for msg in init_messages:
init_messages_objs.append(msg)
assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages)
# Put the messages inside the message buffer

View File

@@ -376,6 +376,7 @@ class RESTClient(AbstractClient):
# metadata
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
initial_message_sequence: Optional[List[Message]] = None,
) -> AgentState:
"""Create an agent
@@ -428,9 +429,18 @@ class RESTClient(AbstractClient):
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
)
# Use model_dump_json() instead of model_dump()
# If we use model_dump(), the datetime objects will not be serialized correctly
# response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers)
response = requests.post(
f"{self.base_url}/{self.api_prefix}/agents",
data=request.model_dump_json(), # Use model_dump_json() instead of json=model_dump()
headers={"Content-Type": "application/json", **self.headers},
)
response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}")
return AgentState(**response.json())
@@ -1648,6 +1658,7 @@ class LocalClient(AbstractClient):
# metadata
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
initial_message_sequence: Optional[List[Message]] = None,
) -> AgentState:
"""Create an agent
@@ -1702,6 +1713,7 @@ class LocalClient(AbstractClient):
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
),
actor=self.user,
)

View File

@@ -1,4 +1,3 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
@@ -105,7 +104,7 @@ class AgentState(BaseAgent, validate_assignment=True):
class CreateAgent(BaseAgent):
# all optional as server can generate defaults
name: Optional[str] = Field(None, description="The name of the agent.")
message_ids: Optional[List[uuid.UUID]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.")
@@ -113,6 +112,11 @@ class CreateAgent(BaseAgent):
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
# Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence
# If the client wants to make this empty, then the client can set the arg to an empty list
initial_message_sequence: Optional[List[Message]] = Field(
None, description="The initial set of messages to put in the agent's in-context memory."
)
@field_validator("name")
@classmethod

View File

@@ -21,6 +21,8 @@ class LettaBase(BaseModel):
from_attributes=True,
# throw errors if attributes are given that don't belong
extra="forbid",
# handle datetime serialization consistently across all models
# json_encoders={datetime: lambda dt: (dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt).isoformat()},
)
# def __id_prefix__(self):

View File

@@ -857,7 +857,10 @@ class SyncServer(Server):
agent_state=agent_state,
tools=tool_objs,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
first_message_verify_mono=(
True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False
),
initial_message_sequence=request.initial_message_sequence,
)
elif request.agent_type == AgentType.o1_agent:
agent = O1Agent(
@@ -865,7 +868,9 @@ class SyncServer(Server):
agent_state=agent_state,
tools=tool_objs,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
first_message_verify_mono=(
True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False
),
)
# rebuilding agent memory on agent create in case shared memory blocks
# were specified in the new agent's memory config. we're doing this for two reasons:

View File

@@ -8,11 +8,12 @@ import pytest
from dotenv import load_dotenv
from letta import create_client
from letta.agent import initialize_message_sequence
from letta.client.client import LocalClient, RESTClient
from letta.constants import DEFAULT_PRESET
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.enums import MessageRole, MessageStreamStatus
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCallMessage,
@@ -28,6 +29,7 @@ from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from letta.services.tool_manager import ToolManager
from letta.settings import model_settings
from letta.utils import get_utc_time
from tests.helpers.client_helper import upload_file_using_client
# from tests.utils import create_config
@@ -598,3 +600,75 @@ def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState
# cleanup
client.delete_agent(agent_state1.id)
client.delete_agent(agent_state2.id)
@pytest.fixture
def cleanup_agents(client: Union[LocalClient, RESTClient]):
    """Collect agent IDs created during a test and delete them on teardown.

    Yields a mutable list; the test appends each created agent's ID to it.
    The code after ``yield`` is pytest teardown, so cleanup runs even when
    the test body fails.

    NOTE: the original version referenced a bare ``client`` name that is not
    defined at module scope (it is itself a fixture), which would raise
    NameError during teardown — it must be injected as a fixture parameter.
    """
    created_agents: List[str] = []
    yield created_agents
    # Teardown: best-effort deletion — log and continue so one failed
    # delete does not mask the others (or the test's own outcome).
    for agent_id in created_agents:
        try:
            client.delete_agent(agent_id)
        except Exception as e:
            print(f"Failed to delete agent {agent_id}: {e}")
def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]):
    """Test that we can set an initial message sequence at agent creation.

    If we pass in None, we should get the "default" message sequence.
    If we pass in a non-empty list, we should get the system prompt followed by that sequence.
    If we pass in an empty list, we should get only the system prompt (the server
    always places the system message up front, so the minimum buffer size is 1).
    """
    # The reference "default" initial message sequence generated server-side
    reference_init_messages = initialize_message_sequence(
        model=agent.llm_config.model,
        system=agent.system,
        memory=agent.memory,
        archival_memory=None,
        recall_memory=None,
        memory_edit_timestamp=get_utc_time(),
        include_initial_boot_message=True,
    )

    # system, login message, send_message test, send_message receipt
    assert len(reference_init_messages) > 0
    assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}"

    # Test with default sequence (None -> server generates the standard sequence)
    default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None)
    cleanup_agents.append(default_agent_state.id)
    assert default_agent_state.message_ids is not None
    assert len(default_agent_state.message_ids) > 0
    assert len(default_agent_state.message_ids) == len(
        reference_init_messages
    ), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}"

    # Test with empty sequence: only the system prompt remains in the buffer
    empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[])
    cleanup_agents.append(empty_agent_state.id)
    assert empty_agent_state.message_ids is not None
    # NOTE: fixed the failure message — it previously said "Expected 0 messages"
    # while asserting a length of 1 (the always-present system message).
    assert (
        len(empty_agent_state.message_ids) == 1
    ), f"Expected 1 message (system prompt only), got {len(empty_agent_state.message_ids)}"

    # Test with a custom sequence: system prompt + the provided messages
    custom_sequence = [
        Message(
            role=MessageRole.user,
            text="Hello, how are you?",
            user_id=agent.user_id,
            agent_id=agent.id,
            model=agent.llm_config.model,
            name=None,
            tool_calls=None,
            tool_call_id=None,
        ),
    ]
    custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence)
    cleanup_agents.append(custom_agent_state.id)
    assert custom_agent_state.message_ids is not None
    assert (
        len(custom_agent_state.message_ids) == len(custom_sequence) + 1
    ), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
    # Everything after the system message must be exactly the provided sequence, in order
    assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]