fix: patch user_id in header (#1843)

This commit is contained in:
Charles Packer
2024-10-08 10:21:07 -07:00
committed by GitHub
parent 1104438490
commit 6b35e87245
8 changed files with 44 additions and 41 deletions

View File

@@ -1,5 +1,5 @@
import uuid
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query
@@ -43,7 +43,7 @@ router = APIRouter(prefix="/v1/threads", tags=["threads"])
def create_thread(
request: CreateThreadRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
@@ -68,7 +68,7 @@ def create_thread(
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
@@ -102,7 +102,7 @@ def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
agent_id = thread_id
@@ -146,7 +146,7 @@ def list_messages(
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id)
after_uuid = after if before else None

View File

@@ -1,5 +1,5 @@
import json
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException
@@ -30,7 +30,7 @@ router = APIRouter(prefix="/v1/chat/completions", tags=["chat_completions"])
async def create_chat_completion(
completion_request: ChatCompletionRequest = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Send a message to a Letta agent via a /chat/completions completion_request
The bearer token will be used to identify the user.

View File

@@ -40,7 +40,7 @@ router = APIRouter(prefix="/agents", tags=["agents"])
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
def list_agents(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all agents associated with a given user.
@@ -55,7 +55,7 @@ def list_agents(
def create_agent(
agent: CreateAgent = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a new agent with the specified configuration.
@@ -76,7 +76,7 @@ def update_agent(
agent_id: str,
update_agent: UpdateAgentState = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Update an exsiting agent"""
actor = server.get_user_or_default(user_id=user_id)
@@ -89,7 +89,7 @@ def update_agent(
def get_agent_state(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get the state of the agent.
@@ -107,7 +107,7 @@ def get_agent_state(
def delete_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete an agent.
@@ -159,7 +159,7 @@ def update_agent_memory(
agent_id: str,
request: Dict = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the core memory of a specific agent.
@@ -202,7 +202,7 @@ def get_agent_archival_memory(
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
limit: Optional[int] = Query(None, description="How many results to include in the response."),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memories in an agent's archival memory store (paginated query).
@@ -227,7 +227,7 @@ def insert_agent_archival_memory(
agent_id: str,
request: CreateArchivalMemory = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Insert a memory into an agent's archival memory store.
@@ -245,7 +245,7 @@ def delete_agent_archival_memory(
memory_id: str,
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a memory from an agent's archival memory store.
@@ -276,7 +276,7 @@ def get_agent_messages(
DEFAULT_MESSAGE_TOOL_KWARG,
description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve message history for an agent.
@@ -315,7 +315,7 @@ async def send_message(
agent_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Process a user message and return the agent's response.

View File

@@ -19,7 +19,7 @@ def list_blocks(
templates_only: bool = Query(True, description="Whether to include only templates"),
name: Optional[str] = Query(None, description="Name of the block"),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
@@ -33,7 +33,7 @@ def list_blocks(
def create_block(
create_block: CreateBlock = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)

View File

@@ -13,7 +13,7 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
def list_jobs(
server: "SyncServer" = Depends(get_letta_server),
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all jobs.
@@ -34,7 +34,7 @@ def list_jobs(
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
def list_active_jobs(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all active jobs.

View File

@@ -1,6 +1,6 @@
import os
import tempfile
from typing import List
from typing import List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile
@@ -21,7 +21,7 @@ router = APIRouter(prefix="/sources", tags=["sources"])
def get_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get all sources
@@ -35,7 +35,7 @@ def get_source(
def get_source_id_by_name(
source_name: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a source by name
@@ -49,7 +49,7 @@ def get_source_id_by_name(
@router.get("/", response_model=List[Source], operation_id="list_sources")
def list_sources(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all data sources created by a user.
@@ -63,7 +63,7 @@ def list_sources(
def create_source(
source: SourceCreate,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a new data source.
@@ -78,7 +78,7 @@ def update_source(
source_id: str,
source: SourceUpdate,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the name or documentation of an existing data source.
@@ -94,7 +94,7 @@ def update_source(
def delete_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a data source.
@@ -109,7 +109,7 @@ def attach_source_to_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Attach a data source to an existing agent.
@@ -127,7 +127,7 @@ def detach_source_from_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
) -> None:
"""
Detach a data source from an existing agent.
@@ -143,7 +143,7 @@ def upload_file_to_source(
source_id: str,
background_tasks: BackgroundTasks,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Upload a file to a data source.
@@ -176,7 +176,7 @@ def upload_file_to_source(
def list_passages(
source_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all passages associated with a data source.
@@ -190,7 +190,7 @@ def list_passages(
def list_documents(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List all documents associated with a data source.

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException
@@ -13,7 +13,7 @@ router = APIRouter(prefix="/tools", tags=["tools"])
def delete_tool(
tool_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Delete a tool by name
@@ -43,7 +43,7 @@ def get_tool(
def get_tool_id(
tool_name: str,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a tool ID by name
@@ -60,7 +60,7 @@ def get_tool_id(
@router.get("/", response_model=List[Tool], operation_id="list_tools")
def list_all_tools(
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get a list of all tools available to agents created by a user
@@ -78,7 +78,7 @@ def create_tool(
tool: ToolCreate = Body(...),
update: bool = False,
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Create a new tool
@@ -98,7 +98,7 @@ def update_tool(
tool_id: str,
request: ToolUpdate = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update an existing tool

View File

@@ -1921,7 +1921,10 @@ class SyncServer(Server):
if user_id is None:
return self.get_default_user()
else:
return self.get_user(user_id=user_id)
try:
return self.get_user(user_id=user_id)
except ValueError:
raise HTTPException(status_code=404, detail=f"User with id {user_id} not found")
def list_llm_models(self) -> List[LLMConfig]:
"""List available models"""