From 02a4f13cb837b99aa20224c552a369619fc416df Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 16 Jan 2025 13:19:35 -0800 Subject: [PATCH] fix: manually expose LettaMessageUnion in openapi spec (#682) --- letta/schemas/letta_message.py | 24 ++++++++++++++++++++++ letta/schemas/letta_response.py | 10 +-------- letta/server/rest_api/app.py | 2 ++ letta/server/rest_api/routers/v1/agents.py | 17 +++++++++++++-- letta/server/rest_api/routers/v1/runs.py | 14 +++++++++++-- 5 files changed, 54 insertions(+), 13 deletions(-) diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index 45fcf361..a49ac049 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -217,3 +217,27 @@ LettaMessageUnion = Annotated[ Union[SystemMessage, UserMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, AssistantMessage], Field(discriminator="message_type"), ] + + +def create_letta_message_union_schema(): + return { + "oneOf": [ + {"$ref": "#/components/schemas/SystemMessage-Output"}, + {"$ref": "#/components/schemas/UserMessage-Output"}, + {"$ref": "#/components/schemas/ReasoningMessage"}, + {"$ref": "#/components/schemas/ToolCallMessage"}, + {"$ref": "#/components/schemas/ToolReturnMessage"}, + {"$ref": "#/components/schemas/AssistantMessage-Output"}, + ], + "discriminator": { + "propertyName": "message_type", + "mapping": { + "system_message": "#/components/schemas/SystemMessage-Output", + "user_message": "#/components/schemas/UserMessage-Output", + "reasoning_message": "#/components/schemas/ReasoningMessage", + "tool_call_message": "#/components/schemas/ToolCallMessage", + "tool_return_message": "#/components/schemas/ToolReturnMessage", + "assistant_message": "#/components/schemas/AssistantMessage-Output", + }, + }, + } diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index fc969d66..ca34a532 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -28,15 +28,7 @@ class LettaResponse(BaseModel): description="The messages returned by the agent.", json_schema_extra={ "items": { - "oneOf": [ - {"$ref": "#/components/schemas/SystemMessage-Output"}, - {"$ref": "#/components/schemas/UserMessage-Output"}, - {"$ref": "#/components/schemas/ReasoningMessage"}, - {"$ref": "#/components/schemas/ToolCallMessage"}, - {"$ref": "#/components/schemas/ToolReturnMessage"}, - {"$ref": "#/components/schemas/AssistantMessage-Output"}, - ], - "discriminator": {"propertyName": "message_type"}, + "$ref": "#/components/schemas/LettaMessageUnion", } }, ) diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index b09d2ddf..b4bd35ec 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -16,6 +16,7 @@ from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError from letta.log import get_logger from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError +from letta.schemas.letta_message import create_letta_message_union_schema from letta.server.constants import REST_DEFAULT_PORT # NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests @@ -67,6 +68,7 @@ def generate_openapi_schema(app: FastAPI): openai_docs["info"]["title"] = "OpenAI Assistants API" letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")} letta_docs["info"]["title"] = "Letta API" + letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema() # Split the API docs into Letta API, and OpenAI Assistants compatible API for name, docs in [ diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 53b0c290..d062a54a 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Optional, Union +from typing import Annotated, List, Optional, Union from fastapi import APIRouter, BackgroundTasks, Body, Depends, Header, HTTPException, Query, status from fastapi.responses import JSONResponse @@ -428,7 +428,20 @@ def delete_agent_archival_memory( return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"}) -@router.get("/{agent_id}/messages", response_model=Union[List[Message], List[LettaMessageUnion]], operation_id="list_agent_messages") +AgentMessagesResponse = Annotated[ + Union[List[Message], List[LettaMessageUnion]], + Field( + json_schema_extra={ + "anyOf": [ + {"type": "array", "items": {"$ref": "#/components/schemas/letta__schemas__message__Message"}}, + {"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}}, + ] + } + ), +] + + +@router.get("/{agent_id}/messages", response_model=AgentMessagesResponse, operation_id="list_agent_messages") def get_agent_messages( agent_id: str, server: "SyncServer" = Depends(get_letta_server), diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index 34cbb889..63bd404f 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -1,6 +1,7 @@ -from typing import List, Optional +from typing import Annotated, List, Optional from fastapi import APIRouter, Depends, Header, HTTPException, Query +from pydantic import Field from letta.orm.enums import JobType from letta.orm.errors import NoResultFound @@ -60,7 +61,16 @@ def get_run( raise HTTPException(status_code=404, detail="Run not found") -@router.get("/{run_id}/messages", response_model=List[LettaMessageUnion], operation_id="get_run_messages") +RunMessagesResponse = Annotated[ + List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}}) +] + + +@router.get( + "/{run_id}/messages", + response_model=RunMessagesResponse, + operation_id="get_run_messages", +) async def get_run_messages( run_id: str, server: "SyncServer" = Depends(get_letta_server),