refactor: remove get_current_user and replace with direct header read (#1834)

This commit is contained in:
Charles Packer
2024-10-07 15:23:08 -07:00
committed by GitHub
parent c76cecb8cb
commit 5501f6d92f
10 changed files with 96 additions and 97 deletions

View File

@@ -2,7 +2,7 @@ import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Union
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.responses import StreamingResponse
@@ -40,12 +40,13 @@ router = APIRouter(prefix="/agents", tags=["agents"])
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
def list_agents(
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
List all agents associated with a given user.
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.list_agents(user_id=actor.id)
@@ -54,11 +55,12 @@ def list_agents(
def create_agent(
agent: CreateAgent = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Create a new agent with the specified configuration.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
agent.user_id = actor.id
# TODO: sarah make general
# TODO: eventually remove this
@@ -74,9 +76,10 @@ def update_agent(
agent_id: str,
update_agent: UpdateAgentState = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""Update an exsiting agent"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
update_agent.id = agent_id
return server.update_agent(update_agent, user_id=actor.id)
@@ -86,11 +89,12 @@ def update_agent(
def get_agent_state(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Get the state of the agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
# agent does not exist
@@ -103,11 +107,12 @@ def get_agent_state(
def delete_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Delete an agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.delete_agent(user_id=actor.id, agent_id=agent_id)
@@ -120,7 +125,6 @@ def get_agent_sources(
"""
Get the sources associated with an agent.
"""
server.get_current_user()
return server.list_attached_sources(agent_id)
@@ -155,12 +159,13 @@ def update_agent_memory(
agent_id: str,
request: Dict = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Update the core memory of a specific agent.
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request)
return memory
@@ -197,11 +202,12 @@ def get_agent_archival_memory(
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
limit: Optional[int] = Query(None, description="How many results to include in the response."),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memories in an agent's archival memory store (paginated query).
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
# TODO need to add support for non-postgres here
# chroma will throw:
@@ -221,11 +227,12 @@ def insert_agent_archival_memory(
agent_id: str,
request: CreateArchivalMemory = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Insert a memory into an agent's archival memory store.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text)
@@ -238,11 +245,12 @@ def delete_agent_archival_memory(
memory_id: str,
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
server: "SyncServer" = Depends(get_letta_server),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Delete a memory from an agent's archival memory store.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
@@ -268,11 +276,12 @@ def get_agent_messages(
DEFAULT_MESSAGE_TOOL_KWARG,
description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Retrieve message history for an agent.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
return server.get_agent_recall_cursor(
user_id=actor.id,
@@ -306,13 +315,14 @@ async def send_message(
agent_id: str,
server: SyncServer = Depends(get_letta_server),
request: LettaRequest = Body(...),
user_id: str = Header(None), # Extract user_id from header, default to None if not present
):
"""
Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent.
It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True.
"""
actor = server.get_current_user()
actor = server.get_user_or_default(user_id=user_id)
# TODO(charles): support sending multiple messages
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"