diff --git a/letta/agent.py b/letta/agent.py index ebd73aa9..7c4ff97e 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -350,6 +350,8 @@ class Agent(BaseAgent): init_messages_objs = [] for msg in init_messages: init_messages_objs.append(msg) + for msg in init_messages_objs: + assert isinstance(msg, Message), f"Message object is not of type Message: {type(msg)}" assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages) # Put the messages inside the message buffer diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 8b5161eb..03a9cef9 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -9,7 +9,7 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_base import LettaBase from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory -from letta.schemas.message import Message +from letta.schemas.message import Message, MessageCreate from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.source import Source from letta.schemas.tool import Tool @@ -124,7 +124,7 @@ class CreateAgent(BaseAgent): # embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.") # Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence # If the client wants to make this empty, then the client can set the arg to an empty list - initial_message_sequence: Optional[List[Message]] = Field( + initial_message_sequence: Optional[List[MessageCreate]] = Field( None, description="The initial set of messages to put in the agent's in-context memory." ) diff --git a/letta/server/server.py b/letta/server/server.py index 350fa45e..32c88be4 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -372,6 +372,22 @@ class SyncServer(Server): } ) + def initialize_agent(self, agent_id, interface: Union[AgentInterface, None] = None, initial_message_sequence=None) -> Agent: + """Initialize an agent from the database""" + agent_state = self.get_agent(agent_id=agent_id) + actor = self.user_manager.get_user_by_id(user_id=agent_state.user_id) + + interface = interface or self.default_interface_factory() + if agent_state.agent_type == AgentType.memgpt_agent: + agent = Agent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence) + else: + assert initial_message_sequence is None, f"Initial message sequence is not supported for O1Agents" + agent = O1Agent(agent_state=agent_state, interface=interface, user=actor) + + # Persist to agent + save_agent(agent, self.ms) + return agent + def load_agent(self, agent_id: str, interface: Union[AgentInterface, None] = None) -> Agent: """Updated method to load agents from persisted storage""" agent_lock = self.per_agent_lock_manager.get_lock(agent_id) @@ -385,6 +401,9 @@ class SyncServer(Server): else: agent = O1Agent(agent_state=agent_state, interface=interface, user=actor) + # Rebuild the system prompt - may be linked to new blocks now + agent.rebuild_system_prompt() + # Persist to agent save_agent(agent, self.ms) return agent @@ -802,8 +821,6 @@ class SyncServer(Server): if not user: raise ValueError(f"cannot find user with associated client id: {user_id}") - # TODO: create the message objects (NOTE: do this after we migrate to `CreateMessage`) - # created and persist the agent state in the DB agent_state = PersistedAgentState( name=request.name, @@ -822,6 +839,31 @@ class SyncServer(Server): # this saves the agent ID and state into the DB self.ms.create_agent(agent_state) + # create the agent object + if request.initial_message_sequence: + # init_messages = [Message(user_id=user_id, agent_id=agent_state.id, role=message.role, text=message.text) for message in request.initial_message_sequence] + init_messages = [] + for message in request.initial_message_sequence: + + if message.role == MessageRole.user: + packed_message = system.package_user_message( + user_message=message.text, + ) + elif message.role == MessageRole.system: + packed_message = system.package_system_message( + system_message=message.text, + ) + else: + raise ValueError(f"Invalid message role: {message.role}") + + init_messages.append(Message(role=message.role, text=packed_message, user_id=user_id, agent_id=agent_state.id)) + # init_messages = [Message.dict_to_message(user_id=user_id, agent_id=agent_state.id, openai_message_dict=message.model_dump()) for message in request.initial_message_sequence] + else: + init_messages = None + + # initialize the agent (generates initial message list with system prompt) + self.initialize_agent(agent_id=agent_state.id, interface=interface, initial_message_sequence=init_messages) + # Note: mappings (e.g. tags, blocks) are created after the agent is persisted # TODO: add source mappings here as well diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 96e27ab6..f49e56ab 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -27,7 +27,6 @@ from letta.schemas.letta_message import ( ) from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics from letta.services.tool_manager import ToolManager from letta.settings import model_settings, tool_settings @@ -624,66 +623,71 @@ def cleanup_agents(): print(f"Failed to delete agent {agent_id}: {e}") -## NOTE: we need to add this back once agents can also create blocks during agent creation -# def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]): -# """Test that we can set an initial message sequence -# -# If we pass in None, we should get a "default" message sequence -# If we pass in a non-empty list, we should get that sequence -# If we pass in an empty list, we should get an empty sequence -# """ -# -# # The reference initial message sequence: -# reference_init_messages = initialize_message_sequence( -# model=agent.llm_config.model, -# system=agent.system, -# memory=agent.memory, -# archival_memory=None, -# recall_memory=None, -# memory_edit_timestamp=get_utc_time(), -# include_initial_boot_message=True, -# ) -# -# # system, login message, send_message test, send_message receipt -# assert len(reference_init_messages) > 0 -# assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}" -# -# # Test with default sequence -# default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None) -# cleanup_agents.append(default_agent_state.id) -# assert default_agent_state.message_ids is not None -# assert len(default_agent_state.message_ids) > 0 -# assert len(default_agent_state.message_ids) == len( -# reference_init_messages -# ), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}" -# -# # Test with empty sequence -# empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[]) -# cleanup_agents.append(empty_agent_state.id) -# # NOTE: allowed to be None initially -# #assert empty_agent_state.message_ids is not None -# #assert len(empty_agent_state.message_ids) == 1, f"Expected 0 messages, got {len(empty_agent_state.message_ids)}" -# -# # Test with custom sequence -# custom_sequence = [ -# Message( -# role=MessageRole.user, -# text="Hello, how are you?", -# user_id=agent.user_id, -# agent_id=agent.id, -# model=agent.llm_config.model, -# name=None, -# tool_calls=None, -# tool_call_id=None, -# ), -# ] -# custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence) -# cleanup_agents.append(custom_agent_state.id) -# assert custom_agent_state.message_ids is not None -# assert ( -# len(custom_agent_state.message_ids) == len(custom_sequence) + 1 -# ), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}" -# assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence] +# NOTE: we need to add this back once agents can also create blocks during agent creation +def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str]): + """Test that we can set an initial message sequence + + If we pass in None, we should get a "default" message sequence + If we pass in a non-empty list, we should get that sequence + If we pass in an empty list, we should get an empty sequence + """ + from letta.agent import initialize_message_sequence + from letta.utils import get_utc_time + + # The reference initial message sequence: + reference_init_messages = initialize_message_sequence( + model=agent.llm_config.model, + system=agent.system, + memory=agent.memory, + archival_memory=None, + recall_memory=None, + memory_edit_timestamp=get_utc_time(), + include_initial_boot_message=True, + ) + + # system, login message, send_message test, send_message receipt + assert len(reference_init_messages) > 0 + assert len(reference_init_messages) == 4, f"Expected 4 messages, got {len(reference_init_messages)}" + + # Test with default sequence + default_agent_state = client.create_agent(name="test-default-message-sequence", initial_message_sequence=None) + cleanup_agents.append(default_agent_state.id) + assert default_agent_state.message_ids is not None + assert len(default_agent_state.message_ids) > 0 + assert len(default_agent_state.message_ids) == len( + reference_init_messages + ), f"Expected {len(reference_init_messages)} messages, got {len(default_agent_state.message_ids)}" + + # Test with empty sequence + empty_agent_state = client.create_agent(name="test-empty-message-sequence", initial_message_sequence=[]) + cleanup_agents.append(empty_agent_state.id) + # NOTE: allowed to be None initially + # assert empty_agent_state.message_ids is not None + # assert len(empty_agent_state.message_ids) == 1, f"Expected 0 messages, got {len(empty_agent_state.message_ids)}" + + # Test with custom sequence + # custom_sequence = [ + # Message( + # role=MessageRole.user, + # text="Hello, how are you?", + # user_id=agent.user_id, + # agent_id=agent.id, + # model=agent.llm_config.model, + # name=None, + # tool_calls=None, + # tool_call_id=None, + # ), + # ] + custom_sequence = [{"text": "Hello, how are you?", "role": "user"}] + custom_agent_state = client.create_agent(name="test-custom-message-sequence", initial_message_sequence=custom_sequence) + cleanup_agents.append(custom_agent_state.id) + assert custom_agent_state.message_ids is not None + assert ( + len(custom_agent_state.message_ids) == len(custom_sequence) + 1 + ), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}" + # assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence] + # shoule be contained in second message (after system message) + assert custom_sequence[0]["text"] in client.get_in_context_messages(custom_agent_state.id)[1].text def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):