fix: allow initial_message_sequence to have assistant message (#1729)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: Kevin Lin <kl2806@columbia.edu>
Co-authored-by: Kevin Lin <klin5061@gmail.com>
This commit is contained in:
Shangyin Tan
2025-04-29 14:42:34 -07:00
committed by GitHub
parent ced32a0124
commit ce81c3bdcd
3 changed files with 130 additions and 11 deletions

View File

@@ -74,6 +74,7 @@ class MessageCreate(BaseModel):
role: Literal[
MessageRole.user,
MessageRole.system,
MessageRole.assistant,
] = Field(..., description="The role of the participant.")
content: Union[str, List[LettaMessageContentUnion]] = Field(
...,

View File

@@ -20,7 +20,7 @@ from letta.schemas.message import Message, MessageCreate
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.tool_rule import ToolRule
from letta.schemas.user import User
from letta.system import get_initial_boot_messages, get_login_event
from letta.system import get_initial_boot_messages, get_login_event, package_function_response
from letta.tracing import trace_method
@@ -282,23 +282,76 @@ def package_initial_message_sequence(
packed_message = system.package_user_message(
user_message=message_create.content,
)
init_messages.append(
Message(
role=message_create.role,
content=[TextContent(text=packed_message)],
name=message_create.name,
organization_id=actor.organization_id,
agent_id=agent_id,
model=model,
)
)
elif message_create.role == MessageRole.system:
packed_message = system.package_system_message(
system_message=message_create.content,
)
init_messages.append(
Message(
role=message_create.role,
content=[TextContent(text=packed_message)],
name=message_create.name,
organization_id=actor.organization_id,
agent_id=agent_id,
model=model,
)
)
elif message_create.role == MessageRole.assistant:
# append tool call to send_message
import json
import uuid
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from letta.constants import DEFAULT_MESSAGE_TOOL
tool_call_id = str(uuid.uuid4())
init_messages.append(
Message(
role=MessageRole.assistant,
content=None,
name=message_create.name,
organization_id=actor.organization_id,
agent_id=agent_id,
model=model,
tool_calls=[
OpenAIToolCall(
id=tool_call_id,
type="function",
function=OpenAIFunction(name=DEFAULT_MESSAGE_TOOL, arguments=json.dumps({"message": message_create.content})),
)
],
)
)
# add tool return
function_response = package_function_response(True, "None")
init_messages.append(
Message(
role=MessageRole.tool,
content=[TextContent(text=function_response)],
name=message_create.name,
organization_id=actor.organization_id,
agent_id=agent_id,
model=model,
tool_call_id=tool_call_id,
)
)
else:
# TODO: add tool call and tool return
raise ValueError(f"Invalid message role: {message_create.role}")
init_messages.append(
Message(
role=message_create.role,
content=[TextContent(text=packed_message)],
name=message_create.name,
organization_id=actor.organization_id,
agent_id=agent_id,
model=model,
)
)
return init_messages

View File

@@ -0,0 +1,65 @@
import os
import threading
import time
import pytest
from dotenv import load_dotenv
from letta_client import Letta, MessageCreate
def run_server():
load_dotenv()
from letta.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)
@pytest.fixture(
scope="module",
)
def client(request):
# Get URL from environment or start server
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
# create the Letta client
yield Letta(base_url=server_url, token=None)
def test_initial_sequence(client: Letta):
# create an agent
agent = client.agents.create(
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
model="letta/letta-free",
embedding="letta/letta-free",
initial_message_sequence=[
MessageCreate(
role="assistant",
content="Hello, how are you?",
),
MessageCreate(role="user", content="I'm good, and you?"),
],
)
# list messages
messages = client.agents.messages.list(agent_id=agent.id)
response = client.agents.messages.create(
agent_id=agent.id,
messages=[
MessageCreate(
role="user",
content="hello assistant!",
)
],
)
assert len(messages) == 3
assert messages[0].message_type == "system_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "user_message"