refactor: remove get_current_user and replace with direct header read (#1834)
This commit is contained in:
@@ -26,7 +26,6 @@ class CreateToolResponse(BaseModel):
|
||||
|
||||
|
||||
def setup_tools_index_router(server: SyncServer, interface: QueuingInterface):
|
||||
# get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.delete("/tools/{tool_name}", tags=["tools"])
|
||||
async def delete_tool(
|
||||
|
||||
@@ -5,8 +5,7 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import FastAPI
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from letta.server.constants import REST_DEFAULT_PORT
|
||||
@@ -84,21 +83,6 @@ def create_application() -> "FastAPI":
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.middleware("http")
|
||||
async def set_current_user_middleware(request: Request, call_next):
|
||||
user_id = request.headers.get("user_id")
|
||||
if user_id:
|
||||
try:
|
||||
server.set_current_user(user_id)
|
||||
except ValueError as e:
|
||||
# Return an HTTP 401 Unauthorized response
|
||||
# raise HTTPException(status_code=401, detail=str(e))
|
||||
return JSONResponse(status_code=401, content={"detail": str(e)})
|
||||
else:
|
||||
server.set_current_user(None)
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
for route in v1_routes:
|
||||
app.include_router(route, prefix=API_PREFIX)
|
||||
# this gives undocumented routes for "latest" and bare api calls.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query
|
||||
|
||||
from letta.constants import DEFAULT_PRESET
|
||||
from letta.schemas.agent import CreateAgent
|
||||
@@ -43,11 +43,12 @@ router = APIRouter(prefix="/v1/threads", tags=["threads"])
|
||||
def create_thread(
|
||||
request: CreateThreadRequest = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
# TODO: use requests.description and requests.metadata fields
|
||||
# TODO: handle requests.file_ids and requests.tools
|
||||
# TODO: eventually allow request to override embedding/llm model
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
print("Create thread/agent", request)
|
||||
# create a letta agent
|
||||
@@ -67,8 +68,9 @@ def create_thread(
|
||||
def retrieve_thread(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
|
||||
assert agent is not None
|
||||
return OpenAIThread(
|
||||
@@ -100,8 +102,9 @@ def create_message(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
request: CreateMessageRequest = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
agent_id = thread_id
|
||||
# create message object
|
||||
message = Message(
|
||||
@@ -143,8 +146,9 @@ def list_messages(
|
||||
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id)
|
||||
after_uuid = after if before else None
|
||||
before_uuid = before if before else None
|
||||
agent_id = thread_id
|
||||
@@ -239,7 +243,6 @@ def create_run(
|
||||
request: CreateRunRequest = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
server.get_current_user()
|
||||
|
||||
# TODO: add request.instructions as a message?
|
||||
agent_id = thread_id
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import FunctionCall, LettaMessage
|
||||
@@ -30,12 +30,14 @@ router = APIRouter(prefix="/v1/chat/completions", tags=["chat_completions"])
|
||||
async def create_chat_completion(
|
||||
completion_request: ChatCompletionRequest = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Send a message to a Letta agent via a /chat/completions completion_request
|
||||
The bearer token will be used to identify the user.
|
||||
The 'user' field in the completion_request should be set to the agent ID.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
agent_id = completion_request.user
|
||||
if agent_id is None:
|
||||
raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field")
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
@@ -40,12 +40,13 @@ router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
|
||||
def list_agents(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.list_agents(user_id=actor.id)
|
||||
|
||||
@@ -54,11 +55,12 @@ def list_agents(
|
||||
def create_agent(
|
||||
agent: CreateAgent = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new agent with the specified configuration.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
agent.user_id = actor.id
|
||||
# TODO: sarah make general
|
||||
# TODO: eventually remove this
|
||||
@@ -74,9 +76,10 @@ def update_agent(
|
||||
agent_id: str,
|
||||
update_agent: UpdateAgentState = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Update an exsiting agent"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
update_agent.id = agent_id
|
||||
return server.update_agent(update_agent, user_id=actor.id)
|
||||
@@ -86,11 +89,12 @@ def update_agent(
|
||||
def get_agent_state(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get the state of the agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
|
||||
# agent does not exist
|
||||
@@ -103,11 +107,12 @@ def get_agent_state(
|
||||
def delete_agent(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete an agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.delete_agent(user_id=actor.id, agent_id=agent_id)
|
||||
|
||||
@@ -120,7 +125,6 @@ def get_agent_sources(
|
||||
"""
|
||||
Get the sources associated with an agent.
|
||||
"""
|
||||
server.get_current_user()
|
||||
|
||||
return server.list_attached_sources(agent_id)
|
||||
|
||||
@@ -155,12 +159,13 @@ def update_agent_memory(
|
||||
agent_id: str,
|
||||
request: Dict = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the core memory of a specific agent.
|
||||
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request)
|
||||
return memory
|
||||
@@ -197,11 +202,12 @@ def get_agent_archival_memory(
|
||||
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
|
||||
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
|
||||
limit: Optional[int] = Query(None, description="How many results to include in the response."),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memories in an agent's archival memory store (paginated query).
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
# TODO need to add support for non-postgres here
|
||||
# chroma will throw:
|
||||
@@ -221,11 +227,12 @@ def insert_agent_archival_memory(
|
||||
agent_id: str,
|
||||
request: CreateArchivalMemory = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Insert a memory into an agent's archival memory store.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text)
|
||||
|
||||
@@ -238,11 +245,12 @@ def delete_agent_archival_memory(
|
||||
memory_id: str,
|
||||
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a memory from an agent's archival memory store.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
|
||||
@@ -268,11 +276,12 @@ def get_agent_messages(
|
||||
DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
|
||||
),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.get_agent_recall_cursor(
|
||||
user_id=actor.id,
|
||||
@@ -306,13 +315,14 @@ async def send_message(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaRequest = Body(...),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
# TODO(charles): support sending multiple messages
|
||||
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
|
||||
from letta.schemas.block import Block, CreateBlock, UpdateBlock
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
@@ -19,8 +19,9 @@ def list_blocks(
|
||||
templates_only: bool = Query(True, description="Whether to include only templates"),
|
||||
name: Optional[str] = Query(None, description="Name of the block"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
blocks = server.get_blocks(user_id=actor.id, label=label, template=templates_only, name=name)
|
||||
if blocks is None:
|
||||
@@ -32,8 +33,9 @@ def list_blocks(
|
||||
def create_block(
|
||||
create_block: CreateBlock = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
create_block.user_id = actor.id
|
||||
return server.create_block(user_id=actor.id, request=create_block)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi import APIRouter, Depends, Header, Query
|
||||
|
||||
from letta.schemas.job import Job
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
@@ -13,11 +13,12 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
def list_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all jobs.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
# TODO: add filtering by status
|
||||
jobs = server.list_jobs(user_id=actor.id)
|
||||
@@ -33,11 +34,12 @@ def list_jobs(
|
||||
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
|
||||
def list_active_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all active jobs.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.list_active_jobs(user_id=actor.id)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import tempfile
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile
|
||||
|
||||
from letta.schemas.document import Document
|
||||
from letta.schemas.job import Job
|
||||
@@ -21,11 +21,12 @@ router = APIRouter(prefix="/sources", tags=["sources"])
|
||||
def get_source(
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get all sources
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.get_source(source_id=source_id, user_id=actor.id)
|
||||
|
||||
@@ -34,11 +35,12 @@ def get_source(
|
||||
def get_source_id_by_name(
|
||||
source_name: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a source by name
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
source_id = server.get_source_id(source_name=source_name, user_id=actor.id)
|
||||
return source_id
|
||||
@@ -47,11 +49,12 @@ def get_source_id_by_name(
|
||||
@router.get("/", response_model=List[Source], operation_id="list_sources")
|
||||
def list_sources(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all data sources created by a user.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.list_all_sources(user_id=actor.id)
|
||||
|
||||
@@ -60,11 +63,12 @@ def list_sources(
|
||||
def create_source(
|
||||
source: SourceCreate,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.create_source(request=source, user_id=actor.id)
|
||||
|
||||
@@ -74,11 +78,13 @@ def update_source(
|
||||
source_id: str,
|
||||
source: SourceUpdate,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the name or documentation of an existing data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
assert source.id == source_id, "Source ID in path must match ID in request body"
|
||||
|
||||
return server.update_source(request=source, user_id=actor.id)
|
||||
@@ -88,11 +94,12 @@ def update_source(
|
||||
def delete_source(
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
server.delete_source(source_id=source_id, user_id=actor.id)
|
||||
|
||||
@@ -102,11 +109,12 @@ def attach_source_to_agent(
|
||||
source_id: str,
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Attach a data source to an existing agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
source = server.ms.get_source(source_id=source_id, user_id=actor.id)
|
||||
assert source is not None, f"Source with id={source_id} not found."
|
||||
@@ -119,11 +127,12 @@ def detach_source_from_agent(
|
||||
source_id: str,
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
) -> None:
|
||||
"""
|
||||
Detach a data source from an existing agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id)
|
||||
|
||||
@@ -134,11 +143,12 @@ def upload_file_to_source(
|
||||
source_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Upload a file to a data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
source = server.ms.get_source(source_id=source_id, user_id=actor.id)
|
||||
assert source is not None, f"Source with id={source_id} not found."
|
||||
@@ -166,11 +176,12 @@ def upload_file_to_source(
|
||||
def list_passages(
|
||||
source_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all passages associated with a data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
|
||||
return passages
|
||||
|
||||
@@ -179,11 +190,12 @@ def list_passages(
|
||||
def list_documents(
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all documents associated with a data source.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
documents = server.list_data_source_documents(user_id=actor.id, source_id=source_id)
|
||||
return documents
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
@@ -13,11 +13,12 @@ router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
def delete_tool(
|
||||
tool_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a tool by name
|
||||
"""
|
||||
# actor = server.get_current_user()
|
||||
# actor = server.get_user_or_default(user_id=user_id)
|
||||
server.delete_tool(tool_id=tool_id)
|
||||
|
||||
|
||||
@@ -42,11 +43,12 @@ def get_tool(
|
||||
def get_tool_id(
|
||||
tool_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a tool ID by name
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
tool_id = server.get_tool_id(tool_name, user_id=actor.id)
|
||||
if tool_id is None:
|
||||
@@ -58,11 +60,12 @@ def get_tool_id(
|
||||
@router.get("/", response_model=List[Tool], operation_id="list_tools")
|
||||
def list_all_tools(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a list of all tools available to agents created by a user
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor.id
|
||||
|
||||
# TODO: add back when user-specific
|
||||
@@ -75,11 +78,12 @@ def create_tool(
|
||||
tool: ToolCreate = Body(...),
|
||||
update: bool = False,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new tool
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.create_tool(
|
||||
request=tool,
|
||||
@@ -94,10 +98,11 @@ def update_tool(
|
||||
tool_id: str,
|
||||
request: ToolUpdate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update an existing tool
|
||||
"""
|
||||
assert tool_id == request.id, "Tool ID in path must match tool ID in request body"
|
||||
server.get_current_user()
|
||||
# actor = server.get_user_or_default(user_id=user_id)
|
||||
return server.update_tool(request)
|
||||
|
||||
@@ -1064,7 +1064,11 @@ class SyncServer(Server):
|
||||
|
||||
def get_user(self, user_id: str) -> User:
|
||||
"""Get the user"""
|
||||
return self.ms.get_user(user_id=user_id)
|
||||
user = self.ms.get_user(user_id=user_id)
|
||||
if user is None:
|
||||
raise ValueError(f"User with user_id {user_id} does not exist")
|
||||
else:
|
||||
return user
|
||||
|
||||
def get_agent_memory(self, agent_id: str) -> Memory:
|
||||
"""Return the memory of an agent (core memory)"""
|
||||
@@ -1880,20 +1884,6 @@ class SyncServer(Server):
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
return letta_agent.retry_message()
|
||||
|
||||
def set_current_user(self, user_id: Optional[str]):
|
||||
"""Very hacky way to set the current user for the server, to be replaced once server becomes stateless
|
||||
|
||||
NOTE: clearly not thread-safe, only exists to provide basic user_id support for REST API for now
|
||||
"""
|
||||
|
||||
# Make sure the user_id actually exists
|
||||
if user_id is not None:
|
||||
user_obj = self.get_user(user_id)
|
||||
if not user_obj:
|
||||
raise ValueError(f"User with id {user_id} not found")
|
||||
|
||||
self._current_user = user_id
|
||||
|
||||
def get_default_user(self) -> User:
|
||||
|
||||
from letta.constants import (
|
||||
@@ -1910,8 +1900,9 @@ class SyncServer(Server):
|
||||
self.ms.create_organization(org)
|
||||
|
||||
# check if default user exists
|
||||
default_user = self.get_user(DEFAULT_USER_ID)
|
||||
if not default_user:
|
||||
try:
|
||||
self.get_user(DEFAULT_USER_ID)
|
||||
except ValueError:
|
||||
user = User(name=DEFAULT_USER_NAME, org_id=DEFAULT_ORG_ID, id=DEFAULT_USER_ID)
|
||||
self.ms.create_user(user)
|
||||
|
||||
@@ -1922,23 +1913,12 @@ class SyncServer(Server):
|
||||
# check if default org exists
|
||||
return self.get_user(DEFAULT_USER_ID)
|
||||
|
||||
# TODO(ethan) wire back to real method in future ORM PR
|
||||
def get_current_user(self) -> User:
|
||||
"""Returns the currently authed user.
|
||||
|
||||
Since server is the core gateway this needs to pass through server as the
|
||||
first touchpoint.
|
||||
"""
|
||||
|
||||
# Check if _current_user is set and if it's non-null:
|
||||
if hasattr(self, "_current_user") and self._current_user is not None:
|
||||
current_user = self.get_user(self._current_user)
|
||||
if not current_user:
|
||||
warnings.warn(f"Provided user '{self._current_user}' not found, using default user")
|
||||
else:
|
||||
return current_user
|
||||
|
||||
return self.get_default_user()
|
||||
def get_user_or_default(self, user_id: Optional[str]) -> User:
|
||||
"""Get the user object for user_id if it exists, otherwise return the default user object"""
|
||||
if user_id is None:
|
||||
return self.get_default_user()
|
||||
else:
|
||||
return self.get_user(user_id=user_id)
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
"""List available models"""
|
||||
|
||||
Reference in New Issue
Block a user