refactor: remove get_current_user and replace with direct header read (#1834)

This commit is contained in:
Charles Packer
2024-10-07 15:23:08 -07:00
committed by GitHub
parent c76cecb8cb
commit 5501f6d92f
10 changed files with 96 additions and 97 deletions

View File

@@ -26,7 +26,6 @@ class CreateToolResponse(BaseModel):
def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
# get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.delete("/tools/{tool_name}", tags=["tools"])
async def delete_tool(

View File

@@ -5,8 +5,7 @@ from pathlib import Path
from typing import Optional
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from letta.server.constants import REST_DEFAULT_PORT
@@ -84,21 +83,6 @@ def create_application() -> "FastAPI":
allow_headers=["*"],
)
@app.middleware("http")
async def set_current_user_middleware(request: Request, call_next):
user_id = request.headers.get("user_id")
if user_id:
try:
server.set_current_user(user_id)
except ValueError as e:
# Return an HTTP 401 Unauthorized response
# raise HTTPException(status_code=401, detail=str(e))
return JSONResponse(status_code=401, content={"detail": str(e)})
else:
server.set_current_user(None)
response = await call_next(request)
return response
for route in v1_routes:
app.include_router(route, prefix=API_PREFIX)
# this gives undocumented routes for "latest" and bare api calls.

View File

@@ -1,7 +1,7 @@
import uuid
from typing import TYPE_CHECKING, List
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query
from letta.constants import DEFAULT_PRESET
from letta.schemas.agent import CreateAgent
@@ -43,11 +43,12 @@ router = APIRouter(prefix="/v1/threads", tags=["threads"])
def create_thread(
request: CreateThreadRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
# TODO: eventually allow request to override embedding/llm model
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
print("Create thread/agent", request)
# create a letta agent
@@ -67,8 +68,9 @@ def create_thread(
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
assert agent is not None
return OpenAIThread(
@@ -100,8 +102,9 @@ def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
agent_id = thread_id
# create message object
message = Message(
@@ -143,8 +146,9 @@ def list_messages(
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
actor = server.get_current_user()
actor = server.get_user_or_default(user_id)
after_uuid = after if before else None
before_uuid = before if before else None
agent_id = thread_id
@@ -239,7 +243,6 @@ def create_run(
request: CreateRunRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
):
server.get_current_user()
# TODO: add request.instructions as a message?
agent_id = thread_id

View File

@@ -1,7 +1,7 @@
import json
from typing import TYPE_CHECKING
from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import FunctionCall, LettaMessage
@@ -30,12 +30,14 @@ router = APIRouter(prefix="/v1/chat/completions", tags=["chat_completions"])
async def create_chat_completion(
completion_request: ChatCompletionRequest = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""Send a message to a Letta agent via a /chat/completions completion_request
The bearer token will be used to identify the user.
The 'user' field in the completion_request should be set to the agent ID.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
agent_id = completion_request.user
if agent_id is None:
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")

View File

@@ -2,7 +2,7 @@ import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Union
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.responses import StreamingResponse
@@ -40,12 +40,13 @@ router = APIRouter(prefix="/agents", tags=["agents"])
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
def list_agents(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
List all agents associated with a given user.
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.list_agents(user_id=actor.id)
@@ -54,11 +55,12 @@ def list_agents(
def create_agent(
agent: CreateAgent = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Create a new agent with the specified configuration.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
agent.user_id = actor.id
# TODO: sarah make general
# TODO: eventually remove this
@@ -74,9 +76,10 @@ def update_agent(
agent_id: str,
update_agent: UpdateAgentState = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""Update an exsiting agent"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
update_agent.id = agent_id
return server.update_agent(update_agent, user_id=actor.id)
@@ -86,11 +89,12 @@ def update_agent(
def get_agent_state(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Get the state of the agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
# agent does not exist
@@ -103,11 +107,12 @@ def get_agent_state(
def delete_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Delete an agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.delete_agent(user_id=actor.id, agent_id=agent_id)
@@ -120,7 +125,6 @@ def get_agent_sources(
"""
Get the sources associated with an agent.
"""
server.get_current_user()
return server.list_attached_sources(agent_id)
@@ -155,12 +159,13 @@ def update_agent_memory(
agent_id: str,
request: Dict = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Update the core memory of a specific agent.
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request)
return memory
@@ -197,11 +202,12 @@ def get_agent_archival_memory(
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
limit: Optional[int] = Query(None, description="How many results to include in the response."),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memories in an agent's archival memory store (paginated query).
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
# TODO need to add support for non-postgres here
# chroma will throw:
@@ -221,11 +227,12 @@ def insert_agent_archival_memory(
agent_id: str,
request: CreateArchivalMemory = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Insert a memory into an agent's archival memory store.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text)
@@ -238,11 +245,12 @@ def delete_agent_archival_memory(
memory_id: str,
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Delete a memory from an agent's archival memory store.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
@@ -268,11 +276,12 @@ def get_agent_messages(
DEFAULT_MESSAGE_TOOL_KWARG,
description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Retrieve message history for an agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.get_agent_recall_cursor(
user_id=actor.id,
@@ -306,13 +315,14 @@ async def send_message(
agent_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent.
It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
# TODO(charles): support sending multiple messages
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from letta.schemas.block import Block, CreateBlock, UpdateBlock
from letta.server.rest_api.utils import get_letta_server
@@ -19,8 +19,9 @@ def list_blocks(
templates_only: bool = Query(True, description="Whether to include only templates"),
name: Optional[str] = Query(None, description="Name of the block"),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
blocks = server.get_blocks(user_id=actor.id, label=label, template=templates_only, name=name)
if blocks is None:
@@ -32,8 +33,9 @@ def list_blocks(
def create_block(
create_block: CreateBlock = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
create_block.user_id = actor.id
return server.create_block(user_id=actor.id, request=create_block)

View File

@@ -1,6 +1,6 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, Query
from fastapi import APIRouter, Depends, Header, Query
from letta.schemas.job import Job
from letta.server.rest_api.utils import get_letta_server
@@ -13,11 +13,12 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
def list_jobs(
server: "SyncServer" = Depends(get_letta_server),
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
List all jobs.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
# TODO: add filtering by status
jobs = server.list_jobs(user_id=actor.id)
@@ -33,11 +34,12 @@ def list_jobs(
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
def list_active_jobs(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
List all active jobs.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.list_active_jobs(user_id=actor.id)

View File

@@ -2,7 +2,7 @@ import os
import tempfile
from typing import List
from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile
from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile
from letta.schemas.document import Document
from letta.schemas.job import Job
@@ -21,11 +21,12 @@ router = APIRouter(prefix="/sources", tags=["sources"])
def get_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Get all sources
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.get_source(source_id=source_id, user_id=actor.id)
@@ -34,11 +35,12 @@ def get_source(
def get_source_id_by_name(
source_name: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Get a source by name
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
source_id = server.get_source_id(source_name=source_name, user_id=actor.id)
return source_id
@@ -47,11 +49,12 @@ def get_source_id_by_name(
@router.get("/", response_model=List[Source], operation_id="list_sources")
def list_sources(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
List all data sources created by a user.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.list_all_sources(user_id=actor.id)
@@ -60,11 +63,12 @@ def list_sources(
def create_source(
source: SourceCreate,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Create a new data source.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.create_source(request=source, user_id=actor.id)
@@ -74,11 +78,13 @@ def update_source(
source_id: str,
source: SourceUpdate,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Update the name or documentation of an existing data source.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
assert source.id == source_id, "Source ID in path must match ID in request body"
return server.update_source(request=source, user_id=actor.id)
@@ -88,11 +94,12 @@ def update_source(
def delete_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Delete a data source.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
server.delete_source(source_id=source_id, user_id=actor.id)
@@ -102,11 +109,12 @@ def attach_source_to_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Attach a data source to an existing agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
source = server.ms.get_source(source_id=source_id, user_id=actor.id)
assert source is not None, f"Source with id={source_id} not found."
@@ -119,11 +127,12 @@ def detach_source_from_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
) -> None:
"""
Detach a data source from an existing agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id)
@@ -134,11 +143,12 @@ def upload_file_to_source(
source_id: str,
background_tasks: BackgroundTasks,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Upload a file to a data source.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
source = server.ms.get_source(source_id=source_id, user_id=actor.id)
assert source is not None, f"Source with id={source_id} not found."
@@ -166,11 +176,12 @@ def upload_file_to_source(
def list_passages(
source_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
List all passages associated with a data source.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
return passages
@@ -179,11 +190,12 @@ def list_passages(
def list_documents(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
List all documents associated with a data source.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
documents = server.list_data_source_documents(user_id=actor.id, source_id=source_id)
return documents

View File

@@ -1,6 +1,6 @@
from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
from letta.server.rest_api.utils import get_letta_server
@@ -13,11 +13,12 @@ router = APIRouter(prefix="/tools", tags=["tools"])
def delete_tool(
tool_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Delete a tool by name
"""
# actor = server.get_current_user()
# actor = server.get_user_or_default(user_id=user_id)
server.delete_tool(tool_id=tool_id)
@@ -42,11 +43,12 @@ def get_tool(
def get_tool_id(
tool_name: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Get a tool ID by name
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
tool_id = server.get_tool_id(tool_name, user_id=actor.id)
if tool_id is None:
@@ -58,11 +60,12 @@ def get_tool_id(
@router.get("/", response_model=List[Tool], operation_id="list_tools")
def list_all_tools(
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Get a list of all tools available to agents created by a user
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
actor.id
# TODO: add back when user-specific
@@ -75,11 +78,12 @@ def create_tool(
tool: ToolCreate = Body(...),
update: bool = False,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Create a new tool
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.create_tool(
request=tool,
@@ -94,10 +98,11 @@ def update_tool(
tool_id: str,
request: ToolUpdate = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Update an existing tool
"""
assert tool_id == request.id, "Tool ID in path must match tool ID in request body"
server.get_current_user()
# actor = server.get_user_or_default(user_id=user_id)
return server.update_tool(request)

View File

@@ -1064,7 +1064,11 @@ class SyncServer(Server):
def get_user(self, user_id: str) -> User:
"""Get the user"""
return self.ms.get_user(user_id=user_id)
user = self.ms.get_user(user_id=user_id)
if user is None:
raise ValueError(f"User with user_id {user_id} does not exist")
else:
return user
def get_agent_memory(self, agent_id: str) -> Memory:
"""Return the memory of an agent (core memory)"""
@@ -1880,20 +1884,6 @@ class SyncServer(Server):
letta_agent = self._get_or_load_agent(agent_id=agent_id)
return letta_agent.retry_message()
def set_current_user(self, user_id: Optional[str]):
"""Very hacky way to set the current user for the server, to be replaced once server becomes stateless
NOTE: clearly not thread-safe, only exists to provide basic user_id support for REST API for now
"""
# Make sure the user_id actually exists
if user_id is not None:
user_obj = self.get_user(user_id)
if not user_obj:
raise ValueError(f"User with id {user_id} not found")
self._current_user = user_id
def get_default_user(self) -> User:
from letta.constants import (
@@ -1910,8 +1900,9 @@ class SyncServer(Server):
self.ms.create_organization(org)
# check if default user exists
default_user = self.get_user(DEFAULT_USER_ID)
if not default_user:
try:
self.get_user(DEFAULT_USER_ID)
except ValueError:
user = User(name=DEFAULT_USER_NAME, org_id=DEFAULT_ORG_ID, id=DEFAULT_USER_ID)
self.ms.create_user(user)
@@ -1922,23 +1913,12 @@ class SyncServer(Server):
# check if default org exists
return self.get_user(DEFAULT_USER_ID)
# TODO(ethan) wire back to real method in future ORM PR
def get_current_user(self) -> User:
"""Returns the currently authed user.
Since server is the core gateway this needs to pass through server as the
first touchpoint.
"""
# Check if _current_user is set and if it's non-null:
if hasattr(self, "_current_user") and self._current_user is not None:
current_user = self.get_user(self._current_user)
if not current_user:
warnings.warn(f"Provided user '{self._current_user}' not found, using default user")
else:
return current_user
return self.get_default_user()
def get_user_or_default(self, user_id: Optional[str]) -> User:
"""Get the user object for user_id if it exists, otherwise return the default user object"""
if user_id is None:
return self.get_default_user()
else:
return self.get_user(user_id=user_id)
def list_llm_models(self) -> List[LLMConfig]:
"""List available models"""