diff --git a/letta/client/client.py b/letta/client/client.py index f8005432..8e1cf629 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -791,7 +791,7 @@ class RESTClient(AbstractClient): name: Optional[str] = None, stream_steps: bool = False, stream_tokens: bool = False, - include_full_message: Optional[bool] = False, + include_full_message: bool = False, ) -> Union[LettaResponse, Generator[LettaStreamingResponse, None, None]]: """ Send a message to an agent @@ -812,7 +812,12 @@ class RESTClient(AbstractClient): # TODO: figure out how to handle stream_steps and stream_tokens # When streaming steps is True, stream_tokens must be False - request = LettaRequest(messages=messages, stream_steps=stream_steps, stream_tokens=stream_tokens, return_message_object=True) + request = LettaRequest( + messages=messages, + stream_steps=stream_steps, + stream_tokens=stream_tokens, + return_message_object=include_full_message, + ) if stream_tokens or stream_steps: from letta.client.streaming import _sse_post @@ -827,12 +832,12 @@ class RESTClient(AbstractClient): response = LettaResponse(**response.json()) # simplify messages - if not include_full_message: - messages = [] - for m in response.messages: - assert isinstance(m, Message) - messages += m.to_letta_message() - response.messages = messages + # if not include_full_message: + # messages = [] + # for m in response.messages: + # assert isinstance(m, Message) + # messages += m.to_letta_message() + # response.messages = messages return response diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index 21cc881d..ef663726 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -3,7 +3,7 @@ from typing import List, Union from pydantic import BaseModel, Field from letta.schemas.enums import MessageStreamStatus -from letta.schemas.letta_message import LettaMessage +from letta.schemas.letta_message import LettaMessage, LettaMessageUnion from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics from letta.utils import json_dumps @@ -21,7 +21,7 @@ class LettaResponse(BaseModel): usage (LettaUsageStatistics): The usage statistics """ - messages: Union[List[Message], List[LettaMessage]] = Field(..., description="The messages returned by the agent.") + messages: Union[List[Message], List[LettaMessageUnion]] = Field(..., description="The messages returned by the agent.") usage: LettaUsageStatistics = Field(..., description="The usage statistics of the agent.") def __str__(self): diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 8514dada..b0047cb6 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Union from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status from fastapi.responses import JSONResponse, StreamingResponse -from starlette.responses import StreamingResponse from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState @@ -359,7 +358,20 @@ def update_message( return server.update_agent_message(agent_id=agent_id, request=request) -@router.post("/{agent_id}/messages", response_model=None, operation_id="create_agent_message") +@router.post( + "/{agent_id}/messages", + response_model=None, + operation_id="create_agent_message", + responses={ + 200: { + "description": "Successful response", + "content": { + "application/json": {"schema": LettaResponse.model_json_schema()}, # Use model_json_schema() instead of model directly + "text/event-stream": {"description": "Server-Sent Events stream"}, + }, + } + }, +) async def send_message( agent_id: str, server: SyncServer = Depends(get_letta_server), @@ -373,7 +385,7 @@ async def send_message( """ actor = server.get_user_or_default(user_id=user_id) - return await send_message_to_agent( + result = await send_message_to_agent( server=server, agent_id=agent_id, user_id=actor.id, @@ -386,6 +398,7 @@ async def send_message( assistant_message_function_name=request.assistant_message_function_name, assistant_message_function_kwarg=request.assistant_message_function_kwarg, ) + return result # TODO: move this into server.py? diff --git a/tests/test_client.py b/tests/test_client.py index 807cd7ab..14e251e0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -13,7 +13,15 @@ from letta.constants import DEFAULT_PRESET from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageStreamStatus -from letta.schemas.letta_message import FunctionCallMessage, InternalMonologue +from letta.schemas.letta_message import ( + AssistantMessage, + FunctionCallMessage, + FunctionReturn, + InternalMonologue, + LettaMessage, + SystemMessage, + UserMessage, +) from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message @@ -121,6 +129,9 @@ def test_agent_interactions(client: Union[LocalClient, RESTClient], agent: Agent message = "Hello, agent!" print("Sending message", message) response = client.user_message(agent_id=agent.id, message=message, include_full_message=True) + # Check the types coming back + assert all([isinstance(m, Message) for m in response.messages]), "All messages should be Message" + print("Response", response) assert isinstance(response.usage, LettaUsageStatistics) assert response.usage.step_count == 1 @@ -129,6 +140,25 @@ def test_agent_interactions(client: Union[LocalClient, RESTClient], agent: Agent assert isinstance(response.messages[0], Message) print(response.messages) + # test that it also works with LettaMessage + message = "Hello again, agent!" + print("Sending message", message) + response = client.user_message(agent_id=agent.id, message=message, include_full_message=False) + assert all([isinstance(m, LettaMessage) for m in response.messages]), "All messages should be LettaMessages" + + # We should also check that the types were cast properly + print("RESPONSE MESSAGES, client type:", type(client)) + print(response.messages) + for letta_message in response.messages: + assert type(letta_message) in [ + SystemMessage, + UserMessage, + InternalMonologue, + FunctionCallMessage, + FunctionReturn, + AssistantMessage, + ], f"Unexpected message type: {type(letta_message)}" + # TODO: add streaming tests