fix: manually expose LettaMessageUnion in openapi spec (#682)

This commit is contained in:
cthomas
2025-01-16 13:19:35 -08:00
committed by GitHub
parent 97e823d29a
commit 02a4f13cb8
5 changed files with 54 additions and 13 deletions

View File

@@ -217,3 +217,27 @@ LettaMessageUnion = Annotated[
Union[SystemMessage, UserMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, AssistantMessage],
Field(discriminator="message_type"),
]
def create_letta_message_union_schema():
return {
"oneOf": [
{"$ref": "#/components/schemas/SystemMessage-Output"},
{"$ref": "#/components/schemas/UserMessage-Output"},
{"$ref": "#/components/schemas/ReasoningMessage"},
{"$ref": "#/components/schemas/ToolCallMessage"},
{"$ref": "#/components/schemas/ToolReturnMessage"},
{"$ref": "#/components/schemas/AssistantMessage-Output"},
],
"discriminator": {
"propertyName": "message_type",
"mapping": {
"system_message": "#/components/schemas/SystemMessage-Output",
"user_message": "#/components/schemas/UserMessage-Output",
"reasoning_message": "#/components/schemas/ReasoningMessage",
"tool_call_message": "#/components/schemas/ToolCallMessage",
"tool_return_message": "#/components/schemas/ToolReturnMessage",
"assistant_message": "#/components/schemas/AssistantMessage-Output",
},
},
}

View File

@@ -28,15 +28,7 @@ class LettaResponse(BaseModel):
description="The messages returned by the agent.",
json_schema_extra={
"items": {
"oneOf": [
{"$ref": "#/components/schemas/SystemMessage-Output"},
{"$ref": "#/components/schemas/UserMessage-Output"},
{"$ref": "#/components/schemas/ReasoningMessage"},
{"$ref": "#/components/schemas/ToolCallMessage"},
{"$ref": "#/components/schemas/ToolReturnMessage"},
{"$ref": "#/components/schemas/AssistantMessage-Output"},
],
"discriminator": {"propertyName": "message_type"},
"$ref": "#/components/schemas/LettaMessageUnion",
}
},
)

View File

@@ -16,6 +16,7 @@ from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
from letta.log import get_logger
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
from letta.schemas.letta_message import create_letta_message_union_schema
from letta.server.constants import REST_DEFAULT_PORT
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
@@ -67,6 +68,7 @@ def generate_openapi_schema(app: FastAPI):
openai_docs["info"]["title"] = "OpenAI Assistants API"
letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")}
letta_docs["info"]["title"] = "Letta API"
letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema()
# Split the API docs into Letta API, and OpenAI Assistants compatible API
for name, docs in [

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import List, Optional, Union
from typing import Annotated, List, Optional, Union
from fastapi import APIRouter, BackgroundTasks, Body, Depends, Header, HTTPException, Query, status
from fastapi.responses import JSONResponse
@@ -428,7 +428,20 @@ def delete_agent_archival_memory(
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
@router.get("/{agent_id}/messages", response_model=Union[List[Message], List[LettaMessageUnion]], operation_id="list_agent_messages")
AgentMessagesResponse = Annotated[
Union[List[Message], List[LettaMessageUnion]],
Field(
json_schema_extra={
"anyOf": [
{"type": "array", "items": {"$ref": "#/components/schemas/letta__schemas__message__Message"}},
{"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}},
]
}
),
]
@router.get("/{agent_id}/messages", response_model=AgentMessagesResponse, operation_id="list_agent_messages")
def get_agent_messages(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),

View File

@@ -1,6 +1,7 @@
from typing import List, Optional
from typing import Annotated, List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from pydantic import Field
from letta.orm.enums import JobType
from letta.orm.errors import NoResultFound
@@ -60,7 +61,16 @@ def get_run(
raise HTTPException(status_code=404, detail="Run not found")
@router.get("/{run_id}/messages", response_model=List[LettaMessageUnion], operation_id="get_run_messages")
RunMessagesResponse = Annotated[
List[LettaMessageUnion], Field(json_schema_extra={"type": "array", "items": {"$ref": "#/components/schemas/LettaMessageUnion"}})
]
@router.get(
"/{run_id}/messages",
response_model=RunMessagesResponse,
operation_id="get_run_messages",
)
async def get_run_messages(
run_id: str,
server: "SyncServer" = Depends(get_letta_server),