fix: patch messages route + unify all the api/agents API routes to use {agent_id} via path parameter (#1129)
Co-authored-by: Robin Goetz <35136007+goetzrobin@users.noreply.github.com>
This commit is contained in:
@@ -124,6 +124,12 @@ class Message(Record):
|
||||
assert tool_call_id is None
|
||||
self.tool_call_id = tool_call_id
|
||||
|
||||
def to_json(self):
|
||||
json_message = vars(self)
|
||||
if json_message["tool_calls"] is not None:
|
||||
json_message["tool_calls"] = [vars(tc) for tc in json_message["tool_calls"]]
|
||||
return json_message
|
||||
|
||||
@staticmethod
|
||||
def dict_to_message(
|
||||
user_id: uuid.UUID,
|
||||
|
||||
@@ -12,7 +12,6 @@ router = APIRouter()
|
||||
|
||||
|
||||
class CommandRequest(BaseModel):
|
||||
agent_id: str = Field(..., description="Identifier of the agent on which the command will be executed.")
|
||||
command: str = Field(..., description="The command to be executed by the agent.")
|
||||
|
||||
|
||||
@@ -23,8 +22,12 @@ class CommandResponse(BaseModel):
|
||||
def setup_agents_command_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.post("/agents/command", tags=["agents"], response_model=CommandResponse)
|
||||
def run_command(request: CommandRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server)):
|
||||
@router.post("/agents/{agent_id}/command", tags=["agents"], response_model=CommandResponse)
|
||||
def run_command(
|
||||
agent_id: uuid.UUID,
|
||||
request: CommandRequest = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Execute a command on a specified agent.
|
||||
|
||||
@@ -34,7 +37,7 @@ def setup_agents_command_router(server: SyncServer, interface: QueuingInterface,
|
||||
"""
|
||||
interface.clear()
|
||||
try:
|
||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
response = server.run_command(user_id=user_id, agent_id=agent_id, command=request.command)
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
@@ -15,12 +15,7 @@ from memgpt.server.server import SyncServer
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class GetAgentRequest(BaseModel):
|
||||
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
|
||||
|
||||
|
||||
class AgentRenameRequest(BaseModel):
|
||||
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
|
||||
agent_name: str = Field(..., description="New name for the agent.")
|
||||
|
||||
|
||||
@@ -51,9 +46,9 @@ def validate_agent_name(name: str) -> str:
|
||||
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/agents/config", tags=["agents"], response_model=GetAgentResponse)
|
||||
@router.get("/agents/{agent_id}/config", tags=["agents"], response_model=GetAgentResponse)
|
||||
def get_agent_config(
|
||||
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
|
||||
agent_id: uuid.UUID,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
@@ -61,9 +56,7 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
|
||||
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
|
||||
"""
|
||||
request = GetAgentRequest(agent_id=agent_id)
|
||||
|
||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
attached_sources = server.list_attached_sources(agent_id=agent_id)
|
||||
|
||||
interface.clear()
|
||||
@@ -90,8 +83,9 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
sources=attached_sources,
|
||||
)
|
||||
|
||||
@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
|
||||
@router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
|
||||
def update_agent_name(
|
||||
agent_id: uuid.UUID,
|
||||
request: AgentRenameRequest = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
@@ -100,7 +94,7 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
|
||||
This changes the name of the agent in the database but does NOT edit the agent's persona.
|
||||
"""
|
||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
|
||||
valid_name = validate_agent_name(request.agent_name)
|
||||
|
||||
@@ -115,13 +109,13 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
|
||||
@router.delete("/agents/{agent_id}", tags=["agents"])
|
||||
def delete_agent(
|
||||
agent_id,
|
||||
agent_id: uuid.UUID,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Delete an agent.
|
||||
"""
|
||||
agent_id = uuid.UUID(agent_id)
|
||||
# agent_id = uuid.UUID(agent_id)
|
||||
|
||||
interface.clear()
|
||||
try:
|
||||
|
||||
@@ -26,7 +26,6 @@ class GetAgentMemoryResponse(BaseModel):
|
||||
|
||||
# NOTE not subclassing CoreMemory since in the request both field are optional
|
||||
class UpdateAgentMemoryRequest(BaseModel):
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent.")
|
||||
human: str = Field(None, description="Human element of the core memory.")
|
||||
persona: str = Field(None, description="Persona element of the core memory.")
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ class MessageRoleType(str, Enum):
|
||||
|
||||
|
||||
class UserMessageRequest(BaseModel):
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent.")
|
||||
message: str = Field(..., description="The message content to be processed by the agent.")
|
||||
stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
|
||||
role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
|
||||
@@ -35,7 +34,6 @@ class UserMessageResponse(BaseModel):
|
||||
|
||||
|
||||
class GetAgentMessagesRequest(BaseModel):
|
||||
agent_id: str = Field(..., description="The unique identifier of the agent.")
|
||||
start: int = Field(..., description="Message index to start on (reverse chronological).")
|
||||
count: int = Field(..., description="How many messages to retrieve.")
|
||||
|
||||
@@ -47,9 +45,9 @@ class GetAgentMessagesResponse(BaseModel):
|
||||
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/agents/message", tags=["agents"], response_model=GetAgentMessagesResponse)
|
||||
@router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=GetAgentMessagesResponse)
|
||||
def get_agent_messages(
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent."),
|
||||
agent_id: uuid.UUID,
|
||||
start: int = Query(..., description="Message index to start on (reverse chronological)."),
|
||||
count: int = Query(..., description="How many messages to retrieve."),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
@@ -59,14 +57,15 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
|
||||
"""
|
||||
# Validate with the Pydantic model (optional)
|
||||
request = GetAgentMessagesRequest(agent_id=agent_id, start=start, count=count)
|
||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
|
||||
interface.clear()
|
||||
messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count)
|
||||
return GetAgentMessagesResponse(messages=messages)
|
||||
|
||||
@router.post("/agents/message", tags=["agents"], response_model=UserMessageResponse)
|
||||
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
|
||||
async def send_message(
|
||||
agent_id: uuid.UUID,
|
||||
request: UserMessageRequest = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
@@ -76,7 +75,7 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
It can optionally stream the response if 'stream' is set to True.
|
||||
"""
|
||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
|
||||
if request.role == "user" or request.role is None:
|
||||
message_func = server.user_message
|
||||
|
||||
@@ -827,7 +827,7 @@ class SyncServer(LockingServer):
|
||||
messages = sorted(page, key=lambda x: x.created_at, reverse=True)
|
||||
|
||||
# convert to json
|
||||
json_messages = [vars(record) for record in messages]
|
||||
json_messages = [record.to_json() for record in messages]
|
||||
return json_messages
|
||||
|
||||
def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
|
||||
|
||||
Reference in New Issue
Block a user