chore: change the name of user id to actor (#1098)
This commit is contained in:
@@ -44,7 +44,7 @@ def list_agents(
|
||||
description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed in tags.",
|
||||
),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(None, description="Limit for pagination"),
|
||||
@@ -58,7 +58,7 @@ def list_agents(
|
||||
List all agents associated with a given user.
|
||||
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Use dictionary comprehension to build kwargs dynamically
|
||||
kwargs = {
|
||||
@@ -91,12 +91,12 @@ def list_agents(
|
||||
def retrieve_agent_context_window(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the context window of a specific agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.get_agent_context_window(agent_id=agent_id, actor=actor)
|
||||
|
||||
@@ -107,21 +107,21 @@ class CreateAgentRequest(CreateAgent):
|
||||
"""
|
||||
|
||||
# Override the user_id field to exclude it from the request body validation
|
||||
user_id: Optional[str] = Field(None, exclude=True)
|
||||
actor_id: Optional[str] = Field(None, exclude=True)
|
||||
|
||||
|
||||
@router.post("/", response_model=AgentState, operation_id="create_agent")
|
||||
def create_agent(
|
||||
agent: CreateAgentRequest = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
"""
|
||||
Create a new agent with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.create_agent(agent, actor=actor)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
@@ -133,10 +133,10 @@ def modify_agent(
|
||||
agent_id: str,
|
||||
update_agent: UpdateAgent = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Update an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.update_agent(agent_id=agent_id, agent_update=update_agent, actor=actor)
|
||||
|
||||
|
||||
@@ -144,10 +144,10 @@ def modify_agent(
|
||||
def list_agent_tools(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Get tools from an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.list_attached_tools(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@@ -156,12 +156,12 @@ def attach_tool(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a tool to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@@ -170,12 +170,12 @@ def detach_tool(
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a tool from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@@ -184,12 +184,12 @@ def attach_source(
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a source to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@@ -198,12 +198,12 @@ def detach_source(
|
||||
agent_id: str,
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a source from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
|
||||
|
||||
|
||||
@@ -211,12 +211,12 @@ def detach_source(
|
||||
def retrieve_agent(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get the state of the agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
@@ -228,12 +228,12 @@ def retrieve_agent(
|
||||
def delete_agent(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
server.agent_manager.delete_agent(agent_id=agent_id, actor=actor)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent id={agent_id} successfully deleted"})
|
||||
@@ -245,12 +245,12 @@ def delete_agent(
|
||||
def list_agent_sources(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get the sources associated with an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.list_attached_sources(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@@ -259,13 +259,13 @@ def list_agent_sources(
|
||||
def retrieve_agent_memory(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memory state of a specific agent.
|
||||
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.get_agent_memory(agent_id=agent_id, actor=actor)
|
||||
|
||||
@@ -275,12 +275,12 @@ def retrieve_core_memory_block(
|
||||
agent_id: str,
|
||||
block_label: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve a memory block from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
||||
@@ -292,12 +292,12 @@ def retrieve_core_memory_block(
|
||||
def list_core_memory_blocks(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memory blocks of a specific agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor)
|
||||
return agent.memory.blocks
|
||||
@@ -311,12 +311,12 @@ def modify_core_memory_block(
|
||||
block_label: str,
|
||||
block_update: BlockUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Updates a memory block of an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
||||
block = server.block_manager.update_block(block.id, block_update=block_update, actor=actor)
|
||||
@@ -332,12 +332,12 @@ def attach_core_memory_block(
|
||||
agent_id: str,
|
||||
block_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Attach a block to an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@@ -346,12 +346,12 @@ def detach_core_memory_block(
|
||||
agent_id: str,
|
||||
block_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Detach a block from an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@@ -362,12 +362,12 @@ def list_archival_memory(
|
||||
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
|
||||
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
|
||||
limit: Optional[int] = Query(None, description="How many results to include in the response."),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memories in an agent's archival memory store (paginated query).
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.get_agent_archival(
|
||||
user_id=actor.id,
|
||||
@@ -383,12 +383,12 @@ def create_archival_memory(
|
||||
agent_id: str,
|
||||
request: CreateArchivalMemory = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Insert a memory into an agent's archival memory store.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.insert_archival_memory(agent_id=agent_id, memory_contents=request.text, actor=actor)
|
||||
|
||||
@@ -401,12 +401,12 @@ def delete_archival_memory(
|
||||
memory_id: str,
|
||||
# memory_id: str = Query(..., description="Unique ID of the memory to be deleted."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a memory from an agent's archival memory store.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
server.delete_archival_memory(memory_id=memory_id, actor=actor)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
|
||||
@@ -427,12 +427,12 @@ def list_messages(
|
||||
use_assistant_message: bool = Query(True, description="Whether to use assistant messages"),
|
||||
assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."),
|
||||
assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.get_agent_recall(
|
||||
user_id=actor.id,
|
||||
@@ -454,13 +454,13 @@ def modify_message(
|
||||
message_id: str,
|
||||
request: MessageUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the details of a message associated with an agent.
|
||||
"""
|
||||
# TODO: Get rid of agent_id here, it's not really relevant
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
|
||||
|
||||
|
||||
@@ -474,13 +474,13 @@ async def send_message(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaRequest = Body(...),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
@@ -513,14 +513,14 @@ async def send_message_streaming(
|
||||
agent_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaStreamingRequest = Body(...),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
@@ -590,13 +590,13 @@ async def send_message_async(
|
||||
background_tasks: BackgroundTasks,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaRequest = Body(...),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Asynchronously process a user message and return a run object.
|
||||
The actual processing happens in the background, and the status can be checked using the run ID.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Create a new job
|
||||
run = Run(
|
||||
@@ -635,8 +635,8 @@ def reset_messages(
|
||||
agent_id: str,
|
||||
add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Resets the messages for an agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages)
|
||||
|
||||
@@ -21,9 +21,9 @@ def list_blocks(
|
||||
templates_only: bool = Query(True, description="Whether to include only templates"),
|
||||
name: Optional[str] = Query(None, description="Name of the block"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name)
|
||||
|
||||
|
||||
@@ -31,9 +31,9 @@ def list_blocks(
|
||||
def create_block(
|
||||
create_block: CreateBlock = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
block = Block(**create_block.model_dump())
|
||||
return server.block_manager.create_or_update_block(actor=actor, block=block)
|
||||
|
||||
@@ -43,9 +43,9 @@ def modify_block(
|
||||
block_id: str,
|
||||
block_update: BlockUpdate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.block_manager.update_block(block_id=block_id, block_update=block_update, actor=actor)
|
||||
|
||||
|
||||
@@ -53,9 +53,9 @@ def modify_block(
|
||||
def delete_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.block_manager.delete_block(block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@@ -63,10 +63,10 @@ def delete_block(
|
||||
def retrieve_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
print("call get block", block_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor)
|
||||
if block is None:
|
||||
@@ -80,13 +80,13 @@ def retrieve_block(
|
||||
def list_agents_for_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Retrieves all agents associated with the specified block.
|
||||
Raises a 404 if the block does not exist.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
agents = server.block_manager.get_agents_for_block(block_id=block_id, actor=actor)
|
||||
return agents
|
||||
|
||||
@@ -22,13 +22,13 @@ def list_identities(
|
||||
after: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a list of all identities in the database
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
identities = server.identity_manager.list_identities(
|
||||
name=name,
|
||||
@@ -51,10 +51,10 @@ def list_identities(
|
||||
def retrieve_identity(
|
||||
identity_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.get_identity(identity_id=identity_id, actor=actor)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -64,11 +64,11 @@ def retrieve_identity(
|
||||
def create_identity(
|
||||
identity: IdentityCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.create_identity(identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -80,11 +80,11 @@ def create_identity(
|
||||
def upsert_identity(
|
||||
identity: IdentityCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.upsert_identity(identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -97,10 +97,10 @@ def modify_identity(
|
||||
identity_id: str,
|
||||
identity: IdentityUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -112,10 +112,10 @@ def modify_identity(
|
||||
def delete_identity(
|
||||
identity_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete an identity by its identifier key
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.identity_manager.delete_identity(identity_id=identity_id, actor=actor)
|
||||
|
||||
@@ -15,12 +15,12 @@ router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
def list_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all jobs.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# TODO: add filtering by status
|
||||
jobs = server.job_manager.list_jobs(actor=actor)
|
||||
@@ -35,12 +35,12 @@ def list_jobs(
|
||||
@router.get("/active", response_model=List[Job], operation_id="list_active_jobs")
|
||||
def list_active_jobs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all active jobs.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running])
|
||||
|
||||
@@ -48,13 +48,13 @@ def list_active_jobs(
|
||||
@router.get("/{job_id}", response_model=Job, operation_id="retrieve_job")
|
||||
def retrieve_job(
|
||||
job_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get the status of a job.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.job_manager.get_job_by_id(job_id=job_id, actor=actor)
|
||||
@@ -65,13 +65,13 @@ def retrieve_job(
|
||||
@router.delete("/{job_id}", response_model=Job, operation_id="delete_job")
|
||||
def delete_job(
|
||||
job_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Delete a job by its job_id.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
job = server.job_manager.delete_job_by_id(job_id=job_id, actor=actor)
|
||||
|
||||
@@ -15,13 +15,15 @@ router = APIRouter(prefix="/providers", tags=["providers"])
|
||||
def list_providers(
|
||||
after: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get a list of all custom providers in the database
|
||||
"""
|
||||
try:
|
||||
providers = server.provider_manager.list_providers(after=after, limit=limit)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -32,13 +34,13 @@ def list_providers(
|
||||
@router.post("/", tags=["providers"], response_model=Provider, operation_id="create_provider")
|
||||
def create_provider(
|
||||
request: ProviderCreate = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new custom provider
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
provider = Provider(**request.model_dump())
|
||||
provider = server.provider_manager.create_provider(provider, actor=actor)
|
||||
@@ -48,25 +50,29 @@ def create_provider(
|
||||
@router.patch("/", tags=["providers"], response_model=Provider, operation_id="modify_provider")
|
||||
def modify_provider(
|
||||
request: ProviderUpdate = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Update an existing custom provider
|
||||
"""
|
||||
provider = server.provider_manager.update_provider(request)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
provider = server.provider_manager.update_provider(request, actor=actor)
|
||||
return provider
|
||||
|
||||
|
||||
@router.delete("/", tags=["providers"], response_model=None, operation_id="delete_provider")
|
||||
def delete_provider(
|
||||
provider_id: str = Query(..., description="The provider_id key to be deleted."),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Delete an existing custom provider
|
||||
"""
|
||||
try:
|
||||
server.provider_manager.delete_provider_by_id(provider_id=provider_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.provider_manager.delete_provider_by_id(provider_id=provider_id, actor=actor)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -18,12 +18,12 @@ router = APIRouter(prefix="/runs", tags=["runs"])
|
||||
@router.get("/", response_model=List[Run], operation_id="list_runs")
|
||||
def list_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all runs.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return [Run.from_job(job) for job in server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)]
|
||||
|
||||
@@ -31,12 +31,12 @@ def list_runs(
|
||||
@router.get("/active", response_model=List[Run], operation_id="list_active_runs")
|
||||
def list_active_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all active runs.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
active_runs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN)
|
||||
|
||||
@@ -46,13 +46,13 @@ def list_active_runs(
|
||||
@router.get("/{run_id}", response_model=Run, operation_id="retrieve_run")
|
||||
def retrieve_run(
|
||||
run_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get the status of a run.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
|
||||
@@ -74,7 +74,7 @@ RunMessagesResponse = Annotated[
|
||||
async def list_run_messages(
|
||||
run_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
|
||||
@@ -102,7 +102,7 @@ async def list_run_messages(
|
||||
if order not in ["asc", "desc"]:
|
||||
raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
messages = server.job_manager.get_run_messages(
|
||||
@@ -122,13 +122,13 @@ async def list_run_messages(
|
||||
@router.get("/{run_id}/usage", response_model=UsageStatistics, operation_id="retrieve_run_usage")
|
||||
def retrieve_run_usage(
|
||||
run_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get usage statistics for a run.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
usage = server.job_manager.get_job_usage(job_id=run_id, actor=actor)
|
||||
@@ -140,13 +140,13 @@ def retrieve_run_usage(
|
||||
@router.delete("/{run_id}", response_model=Run, operation_id="delete_run")
|
||||
def delete_run(
|
||||
run_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Delete a run by its run_id.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
job = server.job_manager.delete_job_by_id(job_id=run_id, actor=actor)
|
||||
|
||||
@@ -25,9 +25,9 @@ logger = get_logger(__name__)
|
||||
def create_sandbox_config(
|
||||
config_create: SandboxConfigCreate,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.sandbox_config_manager.create_or_update_sandbox_config(config_create, actor)
|
||||
|
||||
@@ -35,18 +35,18 @@ def create_sandbox_config(
|
||||
@router.post("/e2b/default", response_model=PydanticSandboxConfig)
|
||||
def create_default_e2b_sandbox_config(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=actor)
|
||||
|
||||
|
||||
@router.post("/local/default", response_model=PydanticSandboxConfig)
|
||||
def create_default_local_sandbox_config(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ def create_default_local_sandbox_config(
|
||||
def create_custom_local_sandbox_config(
|
||||
local_sandbox_config: LocalSandboxConfig,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""
|
||||
Create or update a custom LocalSandboxConfig, including pip_requirements.
|
||||
@@ -67,7 +67,7 @@ def create_custom_local_sandbox_config(
|
||||
)
|
||||
|
||||
# Retrieve the user (actor)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Wrap the LocalSandboxConfig into a SandboxConfigCreate
|
||||
sandbox_config_create = SandboxConfigCreate(config=local_sandbox_config)
|
||||
@@ -83,9 +83,9 @@ def update_sandbox_config(
|
||||
sandbox_config_id: str,
|
||||
config_update: SandboxConfigUpdate,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.update_sandbox_config(sandbox_config_id, config_update, actor)
|
||||
|
||||
|
||||
@@ -93,9 +93,9 @@ def update_sandbox_config(
|
||||
def delete_sandbox_config(
|
||||
sandbox_config_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id, actor)
|
||||
|
||||
|
||||
@@ -105,22 +105,22 @@ def list_sandbox_configs(
|
||||
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
||||
sandbox_type: Optional[SandboxType] = Query(None, description="Filter for this specific sandbox type"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, after=after, sandbox_type=sandbox_type)
|
||||
|
||||
|
||||
@router.post("/local/recreate-venv", response_model=PydanticSandboxConfig)
|
||||
def force_recreate_local_sandbox_venv(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""
|
||||
Forcefully recreate the virtual environment for the local sandbox.
|
||||
Deletes and recreates the venv, then reinstalls required dependencies.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Retrieve the local sandbox config
|
||||
sbx_config = server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor)
|
||||
@@ -162,9 +162,9 @@ def create_sandbox_env_var(
|
||||
sandbox_config_id: str,
|
||||
env_var_create: SandboxEnvironmentVariableCreate,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.create_sandbox_env_var(env_var_create, sandbox_config_id, actor)
|
||||
|
||||
|
||||
@@ -173,9 +173,9 @@ def update_sandbox_env_var(
|
||||
env_var_id: str,
|
||||
env_var_update: SandboxEnvironmentVariableUpdate,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.update_sandbox_env_var(env_var_id, env_var_update, actor)
|
||||
|
||||
|
||||
@@ -183,9 +183,9 @@ def update_sandbox_env_var(
|
||||
def delete_sandbox_env_var(
|
||||
env_var_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.sandbox_config_manager.delete_sandbox_env_var(env_var_id, actor)
|
||||
|
||||
|
||||
@@ -195,7 +195,7 @@ def list_sandbox_env_vars(
|
||||
limit: int = Query(1000, description="Number of results to return"),
|
||||
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
actor_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, after=after)
|
||||
|
||||
@@ -23,12 +23,12 @@ router = APIRouter(prefix="/sources", tags=["sources"])
|
||||
def retrieve_source(
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get all sources
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
if not source:
|
||||
@@ -40,12 +40,12 @@ def retrieve_source(
|
||||
def get_source_id_by_name(
|
||||
source_name: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a source by name
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
||||
if not source:
|
||||
@@ -56,12 +56,12 @@ def get_source_id_by_name(
|
||||
@router.get("/", response_model=List[Source], operation_id="list_sources")
|
||||
def list_sources(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all data sources created by a user.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
return server.list_all_sources(actor=actor)
|
||||
|
||||
@@ -70,12 +70,12 @@ def list_sources(
|
||||
def create_source(
|
||||
source_create: SourceCreate,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
source = Source(**source_create.model_dump())
|
||||
|
||||
return server.source_manager.create_source(source=source, actor=actor)
|
||||
@@ -86,12 +86,12 @@ def modify_source(
|
||||
source_id: str,
|
||||
source: SourceUpdate,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the name or documentation of an existing data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
|
||||
raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.")
|
||||
return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
|
||||
@@ -101,12 +101,12 @@ def modify_source(
|
||||
def delete_source(
|
||||
source_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
server.delete_source(source_id=source_id, actor=actor)
|
||||
|
||||
@@ -117,12 +117,12 @@ def upload_file_to_source(
|
||||
source_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Upload a file to a data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
assert source is not None, f"Source with id={source_id} not found."
|
||||
@@ -151,12 +151,12 @@ def upload_file_to_source(
|
||||
def list_source_passages(
|
||||
source_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all passages associated with a data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
|
||||
return passages
|
||||
|
||||
@@ -167,12 +167,12 @@ def list_source_files(
|
||||
limit: int = Query(1000, description="Number of files to return"),
|
||||
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List paginated files associated with a data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=actor)
|
||||
|
||||
|
||||
@@ -183,12 +183,12 @@ def delete_file_from_source(
|
||||
source_id: str,
|
||||
file_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a data source.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
|
||||
if deleted_file is None:
|
||||
|
||||
@@ -21,13 +21,13 @@ def list_steps(
|
||||
end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'),
|
||||
model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
List steps with optional pagination and date filters.
|
||||
Dates should be provided in ISO 8601 format (e.g. 2025-01-29T15:01:19-08:00)
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
# Convert ISO strings to datetime objects if provided
|
||||
start_dt = datetime.fromisoformat(start_date) if start_date else None
|
||||
@@ -48,14 +48,15 @@ def list_steps(
|
||||
@router.get("/{step_id}", response_model=Step, operation_id="retrieve_step")
|
||||
def retrieve_step(
|
||||
step_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Get a step by ID.
|
||||
"""
|
||||
try:
|
||||
return server.step_manager.get_step(step_id=step_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.step_manager.get_step(step_id=step_id, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
@@ -64,15 +65,15 @@ def retrieve_step(
|
||||
def update_step_transaction_id(
|
||||
step_id: str,
|
||||
transaction_id: str,
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Update the transaction ID for a step.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.step_manager.update_step_transaction_id(actor, step_id=step_id, transaction_id=transaction_id)
|
||||
return server.step_manager.update_step_transaction_id(actor=actor, step_id=step_id, transaction_id=transaction_id)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Step not found")
|
||||
|
||||
@@ -17,11 +17,11 @@ def list_tags(
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
query_text: Optional[str] = Query(None),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get a list of all tags in the database
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tags = server.agent_manager.list_tags(actor=actor, after=after, limit=limit, query_text=query_text)
|
||||
return tags
|
||||
|
||||
@@ -29,12 +29,12 @@ logger = get_logger(__name__)
|
||||
def delete_tool(
|
||||
tool_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Delete a tool by name
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
server.tool_manager.delete_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@@ -42,12 +42,12 @@ def delete_tool(
|
||||
def retrieve_tool(
|
||||
tool_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a tool by ID
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
if tool is None:
|
||||
# return 404 error
|
||||
@@ -61,13 +61,13 @@ def list_tools(
|
||||
limit: Optional[int] = 50,
|
||||
name: Optional[str] = None,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get a list of all tools available to agents belonging to the org of the user
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
if name is not None:
|
||||
tool = server.tool_manager.get_tool_by_name(tool_name=name, actor=actor)
|
||||
return [tool] if tool else []
|
||||
@@ -82,13 +82,13 @@ def list_tools(
|
||||
def create_tool(
|
||||
request: ToolCreate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Create a new tool
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tool = Tool(**request.model_dump())
|
||||
return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor)
|
||||
except UniqueConstraintViolationError as e:
|
||||
@@ -114,13 +114,13 @@ def create_tool(
|
||||
def upsert_tool(
|
||||
request: ToolCreate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Create or update a tool
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor)
|
||||
return tool
|
||||
except UniqueConstraintViolationError as e:
|
||||
@@ -142,13 +142,13 @@ def modify_tool(
|
||||
tool_id: str,
|
||||
request: ToolUpdate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update an existing tool
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.tool_manager.update_tool_by_id(tool_id=tool_id, tool_update=request, actor=actor)
|
||||
except LettaToolCreateError as e:
|
||||
# HTTP 400 == Bad Request
|
||||
@@ -163,12 +163,12 @@ def modify_tool(
|
||||
@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
|
||||
def upsert_base_tools(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Upsert base tools
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.tool_manager.upsert_base_tools(actor=actor)
|
||||
|
||||
|
||||
@@ -176,12 +176,12 @@ def upsert_base_tools(
|
||||
def run_tool_from_source(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: ToolRunFromSource = Body(...),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Attempt to build a tool from source, then run it on the provided arguments
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
return server.run_tool_from_source(
|
||||
@@ -227,12 +227,12 @@ def list_composio_apps(server: SyncServer = Depends(get_letta_server), user_id:
|
||||
def list_composio_actions_by_app(
|
||||
composio_app_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Get a list of all Composio actions for a specific app
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
composio_api_key = get_composio_api_key(actor=actor, logger=logger)
|
||||
if not composio_api_key:
|
||||
raise HTTPException(
|
||||
@@ -246,12 +246,12 @@ def list_composio_actions_by_app(
|
||||
def add_composio_tool(
|
||||
composio_action_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
tool_create = ToolCreate.from_composio(action_name=composio_action_name)
|
||||
|
||||
@@ -25,15 +25,15 @@ class ProviderManager:
|
||||
provider.resolve_identifier()
|
||||
|
||||
new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True))
|
||||
new_provider.create(session)
|
||||
new_provider.create(session, actor=actor)
|
||||
return new_provider.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_provider(self, provider_update: ProviderUpdate) -> PydanticProvider:
|
||||
def update_provider(self, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
|
||||
"""Update provider details."""
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the existing provider by ID
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id)
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id, actor=actor)
|
||||
|
||||
# Update only the fields that are provided in ProviderUpdate
|
||||
update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
@@ -41,31 +41,32 @@ class ProviderManager:
|
||||
setattr(existing_provider, key, value)
|
||||
|
||||
# Commit the updated provider
|
||||
existing_provider.update(session)
|
||||
existing_provider.update(session, actor=actor)
|
||||
return existing_provider.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_provider_by_id(self, provider_id: str):
|
||||
def delete_provider_by_id(self, provider_id: str, actor: PydanticUser):
|
||||
"""Delete a provider."""
|
||||
with self.session_maker() as session:
|
||||
# Clear api key field
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id)
|
||||
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)
|
||||
existing_provider.api_key = None
|
||||
existing_provider.update(session)
|
||||
existing_provider.update(session, actor=actor)
|
||||
|
||||
# Soft delete in provider table
|
||||
existing_provider.delete(session)
|
||||
existing_provider.delete(session, actor=actor)
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticProvider]:
|
||||
def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]:
|
||||
"""List all providers with optional pagination."""
|
||||
with self.session_maker() as session:
|
||||
providers = ProviderModel.list(
|
||||
db_session=session,
|
||||
after=after,
|
||||
limit=limit,
|
||||
actor=actor,
|
||||
)
|
||||
return [provider.to_pydantic() for provider in providers]
|
||||
|
||||
|
||||
@@ -84,9 +84,9 @@ class StepManager:
|
||||
return new_step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_step(self, step_id: str) -> PydanticStep:
|
||||
def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep:
|
||||
with self.session_maker() as session:
|
||||
step = StepModel.read(db_session=session, identifier=step_id)
|
||||
step = StepModel.read(db_session=session, identifier=step_id, actor=actor)
|
||||
return step.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
|
||||
@@ -1194,7 +1194,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
step_ids = set([msg.step_id for msg in get_messages_response])
|
||||
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
||||
for step_id in step_ids:
|
||||
step = server.step_manager.get_step(step_id=step_id)
|
||||
step = server.step_manager.get_step(step_id=step_id, actor=actor)
|
||||
assert step, "Step was not logged correctly"
|
||||
assert step.provider_id == provider.id
|
||||
assert step.provider_name == agent.llm_config.model_endpoint_type
|
||||
@@ -1208,7 +1208,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
assert prompt_tokens == usage.prompt_tokens
|
||||
assert total_tokens == usage.total_tokens
|
||||
|
||||
server.provider_manager.delete_provider_by_id(provider.id)
|
||||
server.provider_manager.delete_provider_by_id(provider.id, actor=actor)
|
||||
|
||||
existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor)
|
||||
|
||||
@@ -1221,7 +1221,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str):
|
||||
step_ids = set([msg.step_id for msg in get_messages_response])
|
||||
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
||||
for step_id in step_ids:
|
||||
step = server.step_manager.get_step(step_id=step_id)
|
||||
step = server.step_manager.get_step(step_id=step_id, actor=actor)
|
||||
assert step, "Step was not logged correctly"
|
||||
assert step.provider_id == None
|
||||
assert step.provider_name == agent.llm_config.model_endpoint_type
|
||||
|
||||
Reference in New Issue
Block a user