refactor: remove get_current_user and replace with direct header read (#1834)
This commit is contained in:
@@ -2,7 +2,7 @@ import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
@@ -40,12 +40,13 @@ router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
|
||||
def list_agents(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.list_agents(user_id=actor.id)
|
||||
|
||||
@@ -54,11 +55,12 @@ def list_agents(
|
||||
def create_agent(
|
||||
agent: CreateAgent = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new agent with the specified configuration.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
agent.user_id = actor.id
|
||||
# TODO: sarah make general
|
||||
# TODO: eventually remove this
|
||||
@@ -74,9 +76,10 @@ def update_agent(
|
||||
agent_id: str,
|
||||
update_agent: UpdateAgentState = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Update an exsiting agent"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
update_agent.id = agent_id
|
||||
return server.update_agent(update_agent, user_id=actor.id)
|
||||
@@ -86,11 +89,12 @@ def update_agent(
|
||||
def get_agent_state(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get the state of the agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
|
||||
# agent does not exist
|
||||
@@ -103,11 +107,12 @@ def get_agent_state(
|
||||
def delete_agent(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete an agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.delete_agent(user_id=actor.id, agent_id=agent_id)
|
||||
|
||||
@@ -120,7 +125,6 @@ def get_agent_sources(
|
||||
"""
|
||||
Get the sources associated with an agent.
|
||||
"""
|
||||
server.get_current_user()
|
||||
|
||||
return server.list_attached_sources(agent_id)
|
||||
|
||||
@@ -155,12 +159,13 @@ def update_agent_memory(
|
||||
agent_id: str,
|
||||
request: Dict = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the core memory of a specific agent.
|
||||
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
memory = server.update_agent_core_memory(user_id=actor.id, agent_id=agent_id, new_memory_contents=request)
|
||||
return memory
|
||||
@@ -197,11 +202,12 @@ def get_agent_archival_memory(
|
||||
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
|
||||
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
|
||||
limit: Optional[int] = Query(None, description="How many results to include in the response."),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memories in an agent's archival memory store (paginated query).
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
# TODO need to add support for non-postgres here
|
||||
# chroma will throw:
|
||||
@@ -221,11 +227,12 @@ def insert_agent_archival_memory(
|
||||
agent_id: str,
|
||||
request: CreateArchivalMemory = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Insert a memory into an agent's archival memory store.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text)
|
||||
|
||||
@@ -238,11 +245,12 @@ def delete_agent_archival_memory(
|
||||
memory_id: str,
|
||||
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a memory from an agent's archival memory store.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
|
||||
@@ -268,11 +276,12 @@ def get_agent_messages(
|
||||
DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
|
||||
),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.get_agent_recall_cursor(
|
||||
user_id=actor.id,
|
||||
@@ -306,13 +315,14 @@ async def send_message(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaRequest = Body(...),
|
||||
user_id: str = Header(None), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
It can optionally stream the response if 'stream_steps' or 'stream_tokens' is set to True.
|
||||
"""
|
||||
actor = server.get_current_user()
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
# TODO(charles): support sending multiple messages
|
||||
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
|
||||
|
||||
Reference in New Issue
Block a user