Files
letta-server/letta/server/rest_api/routers/openai/assistants/threads.py

339 lines
14 KiB
Python

import uuid
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query
from letta.constants import DEFAULT_PRESET
from letta.schemas.agent import CreateAgent
from letta.schemas.enums import MessageRole
from letta.schemas.message import Message
from letta.schemas.openai.openai import (
MessageFile,
OpenAIMessage,
OpenAIRun,
OpenAIRunStep,
OpenAIThread,
Text,
)
from letta.server.rest_api.routers.openai.assistants.schemas import (
CreateMessageRequest,
CreateRunRequest,
CreateThreadRequest,
CreateThreadRunRequest,
DeleteThreadResponse,
ListMessagesResponse,
ModifyMessageRequest,
ModifyRunRequest,
ModifyThreadRequest,
OpenAIThread,
SubmitToolOutputsToRunRequest,
)
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
if TYPE_CHECKING:
from letta.utils import get_utc_time
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
router = APIRouter(prefix="/v1/threads", tags=["threads"])
@router.post("/", response_model=OpenAIThread)
def create_thread(
request: CreateThreadRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
# TODO: eventually allow request to override embedding/llm model
actor = server.get_user_or_default(user_id=user_id)
print("Create thread/agent", request)
# create a letta agent
agent_state = server.create_agent(
request=CreateAgent(),
user_id=actor.id,
)
# TODO: insert messages into recall memory
return OpenAIThread(
id=str(agent_state.id),
created_at=int(agent_state.created_at.timestamp()),
metadata={}, # TODO add metadata?
)
@router.get("/{thread_id}", response_model=OpenAIThread)
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
assert agent is not None
return OpenAIThread(
id=str(agent.id),
created_at=int(agent.created_at.timestamp()),
metadata={}, # TODO add metadata?
)
@router.get("/{thread_id}", response_model=OpenAIThread)
def modify_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: ModifyThreadRequest = Body(...),
):
# TODO: add agent metadata so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/{thread_id}", response_model=DeleteThreadResponse)
def delete_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
# TODO: delete agent
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage)
def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
agent_id = thread_id
# create message object
message = Message(
user_id=actor.id,
agent_id=agent_id,
role=MessageRole(request.role),
text=request.content,
model=None,
tool_calls=None,
tool_call_id=None,
name=None,
)
agent = server.load_agent(agent_id=agent_id)
# add message to agent
agent._append_to_messages([message])
openai_message = OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=(message.text if message.text else ""))],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# TODO(sarah) fill in?
run_id=None,
file_ids=None,
metadata=None,
# file_ids=message.file_ids,
# metadata=message.metadata,
)
return openai_message
@router.get("/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse)
def list_messages(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many messages to retrieve."),
order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id)
after_uuid = after if before else None
before_uuid = before if before else None
agent_id = thread_id
reverse = True if (order == "desc") else False
json_messages = server.get_agent_recall_cursor(
user_id=actor.id,
agent_id=agent_id,
limit=limit,
after=after_uuid,
before=before_uuid,
order_by="created_at",
reverse=reverse,
)
assert isinstance(json_messages, List)
assert all([isinstance(message, Message) for message in json_messages])
assert isinstance(json_messages[0], Message)
print(json_messages[0].text)
# convert to openai style messages
openai_messages = []
for message in json_messages:
assert isinstance(message, Message)
openai_messages.append(
OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=(message.text if message.text else ""))],
role=str(message.role),
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# TODO(sarah) fill in?
run_id=None,
file_ids=None,
metadata=None,
# file_ids=message.file_ids,
# metadata=message.metadata,
)
)
print("MESSAGES", openai_messages)
# TODO: cast back to message objects
return ListMessagesResponse(messages=openai_messages)
@router.get("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def retrieve_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
server: SyncServer = Depends(get_letta_server),
):
agent_id = thread_id
message = server.get_agent_message(agent_id=agent_id, message_id=message_id)
assert message is not None
return OpenAIMessage(
id=message_id,
created_at=int(message.created_at.timestamp()),
content=[Text(text=(message.text if message.text else ""))],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# TODO(sarah) fill in?
run_id=None,
file_ids=None,
metadata=None,
# file_ids=message.file_ids,
# metadata=message.metadata,
)
@router.get("/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile)
def retrieve_message_file(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: implement?
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def modify_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
request: ModifyMessageRequest = Body(...),
):
# TODO: add metada field to message so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun)
def create_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateRunRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
):
# TODO: add request.instructions as a message?
agent_id = thread_id
# TODO: override preset of agent with request.assistant_id
agent = server.load_agent(agent_id=agent_id)
agent.inner_step(messages=[]) # already has messages added
run_id = str(uuid.uuid4())
create_time = int(get_utc_time().timestamp())
return OpenAIRun(
id=run_id,
created_at=create_time,
thread_id=str(agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
status="completed", # TODO: eventaully allow offline execution
expires_at=create_time,
model=agent.agent_state.llm_config.model,
instructions=request.instructions,
)
@router.post("/runs", tags=["runs"], response_model=OpenAIRun)
def create_thread_and_run(
request: CreateThreadRunRequest = Body(...),
):
# TODO: add a bunch of messages and execute
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun])
def list_runs(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many runs to retrieve."),
order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep])
def list_run_steps(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
limit: int = Query(1000, description="How many run steps to retrieve."),
order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def retrieve_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep)
def retrieve_run_step(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
step_id: str = Path(..., description="The unique identifier of the run step."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def modify_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: ModifyRunRequest = Body(...),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: SubmitToolOutputsToRunRequest = Body(...),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun)
def cancel_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")