From cd508cbb6b7eb2834ecfeac48a18adf2cdc4199a Mon Sep 17 00:00:00 2001 From: Shangyin Tan Date: Tue, 29 Apr 2025 14:42:34 -0700 Subject: [PATCH] fix: allow `initial_message_sequence` to have assistant message (#1729) Co-authored-by: Sarah Wooders Co-authored-by: Kevin Lin Co-authored-by: Kevin Lin --- letta/schemas/message.py | 1 + .../services/helpers/agent_manager_helper.py | 75 ++++++++++++++++--- tests/integration_test_initial_sequence.py | 65 ++++++++++++++++ 3 files changed, 130 insertions(+), 11 deletions(-) create mode 100644 tests/integration_test_initial_sequence.py diff --git a/letta/schemas/message.py b/letta/schemas/message.py index fa075869..b3cfc255 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -74,6 +74,7 @@ class MessageCreate(BaseModel): role: Literal[ MessageRole.user, MessageRole.system, + MessageRole.assistant, ] = Field(..., description="The role of the participant.") content: Union[str, List[LettaMessageContentUnion]] = Field( ..., diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 201550f0..236d14c2 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -20,7 +20,7 @@ from letta.schemas.message import Message, MessageCreate from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.tool_rule import ToolRule from letta.schemas.user import User -from letta.system import get_initial_boot_messages, get_login_event +from letta.system import get_initial_boot_messages, get_login_event, package_function_response from letta.tracing import trace_method @@ -282,23 +282,76 @@ def package_initial_message_sequence( packed_message = system.package_user_message( user_message=message_create.content, ) + init_messages.append( + Message( + role=message_create.role, + content=[TextContent(text=packed_message)], + name=message_create.name, + organization_id=actor.organization_id, + agent_id=agent_id, + model=model, + ) + ) elif message_create.role == MessageRole.system: packed_message = system.package_system_message( system_message=message_create.content, ) + init_messages.append( + Message( + role=message_create.role, + content=[TextContent(text=packed_message)], + name=message_create.name, + organization_id=actor.organization_id, + agent_id=agent_id, + model=model, + ) + ) + elif message_create.role == MessageRole.assistant: + # append tool call to send_message + import json + import uuid + + from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall + from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction + + from letta.constants import DEFAULT_MESSAGE_TOOL + + tool_call_id = str(uuid.uuid4()) + init_messages.append( + Message( + role=MessageRole.assistant, + content=None, + name=message_create.name, + organization_id=actor.organization_id, + agent_id=agent_id, + model=model, + tool_calls=[ + OpenAIToolCall( + id=tool_call_id, + type="function", + function=OpenAIFunction(name=DEFAULT_MESSAGE_TOOL, arguments=json.dumps({"message": message_create.content})), + ) + ], + ) + ) + + # add tool return + function_response = package_function_response(True, "None") + init_messages.append( + Message( + role=MessageRole.tool, + content=[TextContent(text=function_response)], + name=message_create.name, + organization_id=actor.organization_id, + agent_id=agent_id, + model=model, + tool_call_id=tool_call_id, + ) + ) else: + # TODO: add tool call and tool return raise ValueError(f"Invalid message role: {message_create.role}") - init_messages.append( - Message( - role=message_create.role, - content=[TextContent(text=packed_message)], - name=message_create.name, - organization_id=actor.organization_id, - agent_id=agent_id, - model=model, - ) - ) return init_messages diff --git a/tests/integration_test_initial_sequence.py b/tests/integration_test_initial_sequence.py new file mode 100644 index 00000000..71449171 --- /dev/null +++ b/tests/integration_test_initial_sequence.py @@ -0,0 +1,65 @@ +import os +import threading +import time + +import pytest +from dotenv import load_dotenv +from letta_client import Letta, MessageCreate + + +def run_server(): + load_dotenv() + + from letta.server.rest_api.app import start_server + + print("Starting server...") + start_server(debug=True) + + +@pytest.fixture( + scope="module", +) +def client(request): + # Get URL from environment or start server + server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283") + if not os.getenv("LETTA_SERVER_URL"): + print("Starting server thread") + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + time.sleep(5) + print("Running client tests with server:", server_url) + + # create the Letta client + yield Letta(base_url=server_url, token=None) + + +def test_initial_sequence(client: Letta): + # create an agent + agent = client.agents.create( + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + initial_message_sequence=[ + MessageCreate( + role="assistant", + content="Hello, how are you?", + ), + MessageCreate(role="user", content="I'm good, and you?"), + ], + ) + + # list messages + messages = client.agents.messages.list(agent_id=agent.id) + response = client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content="hello assistant!", + ) + ], + ) + assert len(messages) == 3 + assert messages[0].message_type == "system_message" + assert messages[1].message_type == "assistant_message" + assert messages[2].message_type == "user_message"