fix: allow initial_message_sequence to have assistant message (#1729)
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com> Co-authored-by: Kevin Lin <kl2806@columbia.edu> Co-authored-by: Kevin Lin <klin5061@gmail.com>
This commit is contained in:
@@ -74,6 +74,7 @@ class MessageCreate(BaseModel):
|
||||
role: Literal[
|
||||
MessageRole.user,
|
||||
MessageRole.system,
|
||||
MessageRole.assistant,
|
||||
] = Field(..., description="The role of the participant.")
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
|
||||
@@ -20,7 +20,7 @@ from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.schemas.user import User
|
||||
from letta.system import get_initial_boot_messages, get_login_event
|
||||
from letta.system import get_initial_boot_messages, get_login_event, package_function_response
|
||||
from letta.tracing import trace_method
|
||||
|
||||
|
||||
@@ -282,23 +282,76 @@ def package_initial_message_sequence(
|
||||
packed_message = system.package_user_message(
|
||||
user_message=message_create.content,
|
||||
)
|
||||
init_messages.append(
|
||||
Message(
|
||||
role=message_create.role,
|
||||
content=[TextContent(text=packed_message)],
|
||||
name=message_create.name,
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
)
|
||||
)
|
||||
elif message_create.role == MessageRole.system:
|
||||
packed_message = system.package_system_message(
|
||||
system_message=message_create.content,
|
||||
)
|
||||
init_messages.append(
|
||||
Message(
|
||||
role=message_create.role,
|
||||
content=[TextContent(text=packed_message)],
|
||||
name=message_create.name,
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
)
|
||||
)
|
||||
elif message_create.role == MessageRole.assistant:
|
||||
# append tool call to send_message
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL
|
||||
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
init_messages.append(
|
||||
Message(
|
||||
role=MessageRole.assistant,
|
||||
content=None,
|
||||
name=message_create.name,
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
tool_calls=[
|
||||
OpenAIToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=OpenAIFunction(name=DEFAULT_MESSAGE_TOOL, arguments=json.dumps({"message": message_create.content})),
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
# add tool return
|
||||
function_response = package_function_response(True, "None")
|
||||
init_messages.append(
|
||||
Message(
|
||||
role=MessageRole.tool,
|
||||
content=[TextContent(text=function_response)],
|
||||
name=message_create.name,
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# TODO: add tool call and tool return
|
||||
raise ValueError(f"Invalid message role: {message_create.role}")
|
||||
|
||||
init_messages.append(
|
||||
Message(
|
||||
role=message_create.role,
|
||||
content=[TextContent(text=packed_message)],
|
||||
name=message_create.name,
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
)
|
||||
)
|
||||
return init_messages
|
||||
|
||||
|
||||
|
||||
65
tests/integration_test_initial_sequence.py
Normal file
65
tests/integration_test_initial_sequence.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta, MessageCreate
|
||||
|
||||
|
||||
def run_server():
|
||||
load_dotenv()
|
||||
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
print("Starting server...")
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
)
|
||||
def client(request):
|
||||
# Get URL from environment or start server
|
||||
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
print("Starting server thread")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5)
|
||||
print("Running client tests with server:", server_url)
|
||||
|
||||
# create the Letta client
|
||||
yield Letta(base_url=server_url, token=None)
|
||||
|
||||
|
||||
def test_initial_sequence(client: Letta):
|
||||
# create an agent
|
||||
agent = client.agents.create(
|
||||
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
|
||||
model="letta/letta-free",
|
||||
embedding="letta/letta-free",
|
||||
initial_message_sequence=[
|
||||
MessageCreate(
|
||||
role="assistant",
|
||||
content="Hello, how are you?",
|
||||
),
|
||||
MessageCreate(role="user", content="I'm good, and you?"),
|
||||
],
|
||||
)
|
||||
|
||||
# list messages
|
||||
messages = client.agents.messages.list(agent_id=agent.id)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="hello assistant!",
|
||||
)
|
||||
],
|
||||
)
|
||||
assert len(messages) == 3
|
||||
assert messages[0].message_type == "system_message"
|
||||
assert messages[1].message_type == "assistant_message"
|
||||
assert messages[2].message_type == "user_message"
|
||||
Reference in New Issue
Block a user