diff --git a/letta/agent.py b/letta/agent.py index a5297723..724ca5b5 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -235,6 +235,7 @@ class Agent(BaseAgent): # extras messages_total: Optional[int] = None, # TODO remove? first_message_verify_mono: bool = True, # TODO move to config? + initial_message_sequence: Optional[List[Message]] = None, ): assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}" # Hold a copy of the state that was used to init the agent @@ -294,6 +295,7 @@ class Agent(BaseAgent): else: printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}") + assert self.agent_state.id is not None and self.agent_state.user_id is not None # Generate a sequence of initial messages to put in the buffer init_messages = initialize_message_sequence( @@ -306,14 +308,40 @@ class Agent(BaseAgent): include_initial_boot_message=True, ) - # Cast the messages to actual Message objects to be synced to the DB - init_messages_objs = [] - for msg in init_messages: - init_messages_objs.append( + if initial_message_sequence is not None: + # We always need the system prompt up front + system_message_obj = Message.dict_to_message( + agent_id=self.agent_state.id, + user_id=self.agent_state.user_id, + model=self.model, + openai_message_dict=init_messages[0], + ) + # Don't use anything else in the pregen sequence, instead use the provided sequence + init_messages = [system_message_obj] + initial_message_sequence + + else: + # Basic "more human than human" initial message sequence + init_messages = initialize_message_sequence( + model=self.model, + system=self.system, + memory=self.memory, + archival_memory=None, + recall_memory=None, + memory_edit_timestamp=get_utc_time(), + include_initial_boot_message=True, + ) + # Cast to Message objects + init_messages = [ Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=msg ) - ) + for msg in 
init_messages + ] + + # Cast the messages to actual Message objects to be synced to the DB + init_messages_objs = [] + for msg in init_messages: + init_messages_objs.append(msg) assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages) # Put the messages inside the message buffer diff --git a/letta/client/client.py b/letta/client/client.py index f4eb2211..3dd1a814 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -376,6 +376,7 @@ class RESTClient(AbstractClient): # metadata metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, description: Optional[str] = None, + initial_message_sequence: Optional[List[Message]] = None, ) -> AgentState: """Create an agent @@ -428,9 +429,18 @@ class RESTClient(AbstractClient): agent_type=agent_type, llm_config=llm_config if llm_config else self._default_llm_config, embedding_config=embedding_config if embedding_config else self._default_embedding_config, + initial_message_sequence=initial_message_sequence, + ) + + # Use model_dump_json() instead of model_dump() + # If we use model_dump(), the datetime objects will not be serialized correctly + # response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers) + response = requests.post( + f"{self.base_url}/{self.api_prefix}/agents", + data=request.model_dump_json(), # Use model_dump_json() instead of json=model_dump() + headers={"Content-Type": "application/json", **self.headers}, ) - response = requests.post(f"{self.base_url}/{self.api_prefix}/agents", json=request.model_dump(), headers=self.headers) if response.status_code != 200: raise ValueError(f"Status {response.status_code} - Failed to create agent: {response.text}") return AgentState(**response.json()) @@ -1648,6 +1658,7 @@ class LocalClient(AbstractClient): # metadata metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, description: Optional[str] = 
None, + initial_message_sequence: Optional[List[Message]] = None, ) -> AgentState: """Create an agent @@ -1702,6 +1713,7 @@ class LocalClient(AbstractClient): agent_type=agent_type, llm_config=llm_config if llm_config else self._default_llm_config, embedding_config=embedding_config if embedding_config else self._default_embedding_config, + initial_message_sequence=initial_message_sequence, ), actor=self.user, ) diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index f0099dc1..92661024 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -1,4 +1,3 @@ -import uuid from datetime import datetime from enum import Enum from typing import Dict, List, Optional @@ -105,7 +104,7 @@ class AgentState(BaseAgent, validate_assignment=True): class CreateAgent(BaseAgent): # all optional as server can generate defaults name: Optional[str] = Field(None, description="The name of the agent.") - message_ids: Optional[List[uuid.UUID]] = Field(None, description="The ids of the messages in the agent's in-context memory.") + message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.") memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.") tools: Optional[List[str]] = Field(None, description="The tools used by the agent.") tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.") @@ -113,6 +112,11 @@ class CreateAgent(BaseAgent): agent_type: Optional[AgentType] = Field(None, description="The type of agent.") llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.") embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.") + # Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence + # If the client wants to make this empty, then the client can set the arg to an empty 
list + initial_message_sequence: Optional[List[Message]] = Field( + None, description="The initial set of messages to put in the agent's in-context memory." + ) @field_validator("name") @classmethod diff --git a/letta/schemas/letta_base.py b/letta/schemas/letta_base.py index f2b2b09f..3855a3ab 100644 --- a/letta/schemas/letta_base.py +++ b/letta/schemas/letta_base.py @@ -21,6 +21,8 @@ class LettaBase(BaseModel): from_attributes=True, # throw errors if attributes are given that don't belong extra="forbid", + # handle datetime serialization consistently across all models + # json_encoders={datetime: lambda dt: (dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt).isoformat()}, ) # def __id_prefix__(self): diff --git a/letta/server/server.py b/letta/server/server.py index 9a4b318e..2fbd7806 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -857,7 +857,10 @@ class SyncServer(Server): agent_state=agent_state, tools=tool_objs, # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False, + first_message_verify_mono=( + True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False + ), + initial_message_sequence=request.initial_message_sequence, ) elif request.agent_type == AgentType.o1_agent: agent = O1Agent( @@ -865,7 +868,9 @@ class SyncServer(Server): agent_state=agent_state, tools=tool_objs, # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False, + first_message_verify_mono=( + True if (llm_config and llm_config.model is not None and "gpt-4" in llm_config.model) else False + ), ) # rebuilding agent memory on agent create in case shared memory blocks # were specified in the new agent's memory config. 
we're doing this for two reasons: diff --git a/tests/test_client.py b/tests/test_client.py index 4cab823d..f1a4d090 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,11 +8,12 @@ import pytest from dotenv import load_dotenv from letta import create_client +from letta.agent import initialize_message_sequence from letta.client.client import LocalClient, RESTClient from letta.constants import DEFAULT_PRESET from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageStreamStatus +from letta.schemas.enums import MessageRole, MessageStreamStatus from letta.schemas.letta_message import ( AssistantMessage, FunctionCallMessage, @@ -28,6 +29,7 @@ from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics from letta.services.tool_manager import ToolManager from letta.settings import model_settings +from letta.utils import get_utc_time from tests.helpers.client_helper import upload_file_using_client # from tests.utils import create_config @@ -598,3 +600,75 @@ def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState # cleanup client.delete_agent(agent_state1.id) client.delete_agent(agent_state2.id) + + +@pytest.fixture +def cleanup_agents(): + created_agents = [] + yield created_agents + # Cleanup will run even if test fails + for agent_id in created_agents: + try: + client.delete_agent(agent_id) + except Exception as e: + print(f"Failed to delete agent {agent_id}: {e}") + + +def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]): + """Test that we can set an initial message sequence + + If we pass in None, we should get a "default" message sequence + If we pass in a non-empty list, we should get that sequence + If we pass in an empty list, we should get an empty sequence + """ + + # The reference initial message sequence: + reference_init_messages = 
initialize_message_sequence( + model=agent.llm_config.model, + system=agent.system, + memory=agent.memory, + archival_memory=None, + recall_memory=None, + memory_edit_timestamp=get_utc_time(), + include_initial_boot_message=True, + ) + + # system, login message, send_message test, send_message receipt + assert len(reference_init_messages) > 0 + assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}" + + # Test with default sequence + default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None) + cleanup_agents.append(default_agent_state.id) + assert default_agent_state.message_ids is not None + assert len(default_agent_state.message_ids) > 0 + assert len(default_agent_state.message_ids) == len( + reference_init_messages + ), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}" + + # Test with empty sequence + empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[]) + cleanup_agents.append(empty_agent_state.id) + assert empty_agent_state.message_ids is not None + assert len(empty_agent_state.message_ids) == 1, f"Expected 1 message (system prompt only), got {len(empty_agent_state.message_ids)}" + + # Test with custom sequence + custom_sequence = [ + Message( + role=MessageRole.user, + text="Hello, how are you?", + user_id=agent.user_id, + agent_id=agent.id, + model=agent.llm_config.model, + name=None, + tool_calls=None, + tool_call_id=None, + ), + ] + custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence) + cleanup_agents.append(custom_agent_state.id) + assert custom_agent_state.message_ids is not None + assert ( + len(custom_agent_state.message_ids) == len(custom_sequence) + 1 + ), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}" + assert custom_agent_state.message_ids[1:] == [msg.id for msg in 
custom_sequence]