fix: patch messages route + unify all the api/agents API routes to use {agent_id} via path parameter (#1129)

Co-authored-by: Robin Goetz <35136007+goetzrobin@users.noreply.github.com>
This commit is contained in:
Charles Packer
2024-03-11 14:30:58 -07:00
committed by GitHub
parent e478c09ad3
commit 6dc041711c
6 changed files with 28 additions and 27 deletions

View File

@@ -124,6 +124,12 @@ class Message(Record):
assert tool_call_id is None
self.tool_call_id = tool_call_id
def to_json(self):
json_message = vars(self)
if json_message["tool_calls"] is not None:
json_message["tool_calls"] = [vars(tc) for tc in json_message["tool_calls"]]
return json_message
@staticmethod
def dict_to_message(
user_id: uuid.UUID,

View File

@@ -12,7 +12,6 @@ router = APIRouter()
class CommandRequest(BaseModel):
agent_id: str = Field(..., description="Identifier of the agent on which the command will be executed.")
command: str = Field(..., description="The command to be executed by the agent.")
@@ -23,8 +22,12 @@ class CommandResponse(BaseModel):
def setup_agents_command_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.post("/agents/command", tags=["agents"], response_model=CommandResponse)
def run_command(request: CommandRequest = Body(...), user_id: uuid.UUID = Depends(get_current_user_with_server)):
@router.post("/agents/{agent_id}/command", tags=["agents"], response_model=CommandResponse)
def run_command(
agent_id: uuid.UUID,
request: CommandRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Execute a command on a specified agent.
@@ -34,7 +37,7 @@ def setup_agents_command_router(server: SyncServer, interface: QueuingInterface,
"""
interface.clear()
try:
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
response = server.run_command(user_id=user_id, agent_id=agent_id, command=request.command)
except HTTPException:
raise

View File

@@ -15,12 +15,7 @@ from memgpt.server.server import SyncServer
router = APIRouter()
class GetAgentRequest(BaseModel):
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
class AgentRenameRequest(BaseModel):
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
agent_name: str = Field(..., description="New name for the agent.")
@@ -51,9 +46,9 @@ def validate_agent_name(name: str) -> str:
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/config", tags=["agents"], response_model=GetAgentResponse)
@router.get("/agents/{agent_id}/config", tags=["agents"], response_model=GetAgentResponse)
def get_agent_config(
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
@@ -61,9 +56,7 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
"""
request = GetAgentRequest(agent_id=agent_id)
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
attached_sources = server.list_attached_sources(agent_id=agent_id)
interface.clear()
@@ -90,8 +83,9 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
sources=attached_sources,
)
@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
@router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
def update_agent_name(
agent_id: uuid.UUID,
request: AgentRenameRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
@@ -100,7 +94,7 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
This changes the name of the agent in the database but does NOT edit the agent's persona.
"""
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
valid_name = validate_agent_name(request.agent_name)
@@ -115,13 +109,13 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
agent_id,
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete an agent.
"""
agent_id = uuid.UUID(agent_id)
# agent_id = uuid.UUID(agent_id)
interface.clear()
try:

View File

@@ -26,7 +26,6 @@ class GetAgentMemoryResponse(BaseModel):
# NOTE not subclassing CoreMemory since in the request both field are optional
class UpdateAgentMemoryRequest(BaseModel):
agent_id: str = Field(..., description="The unique identifier of the agent.")
human: str = Field(None, description="Human element of the core memory.")
persona: str = Field(None, description="Persona element of the core memory.")

View File

@@ -24,7 +24,6 @@ class MessageRoleType(str, Enum):
class UserMessageRequest(BaseModel):
agent_id: str = Field(..., description="The unique identifier of the agent.")
message: str = Field(..., description="The message content to be processed by the agent.")
stream: bool = Field(default=False, description="Flag to determine if the response should be streamed. Set to True for streaming.")
role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
@@ -35,7 +34,6 @@ class UserMessageResponse(BaseModel):
class GetAgentMessagesRequest(BaseModel):
agent_id: str = Field(..., description="The unique identifier of the agent.")
start: int = Field(..., description="Message index to start on (reverse chronological).")
count: int = Field(..., description="How many messages to retrieve.")
@@ -47,9 +45,9 @@ class GetAgentMessagesResponse(BaseModel):
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/message", tags=["agents"], response_model=GetAgentMessagesResponse)
@router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=GetAgentMessagesResponse)
def get_agent_messages(
agent_id: str = Query(..., description="The unique identifier of the agent."),
agent_id: uuid.UUID,
start: int = Query(..., description="Message index to start on (reverse chronological)."),
count: int = Query(..., description="How many messages to retrieve."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
@@ -59,14 +57,15 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
"""
# Validate with the Pydantic model (optional)
request = GetAgentMessagesRequest(agent_id=agent_id, start=start, count=count)
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
interface.clear()
messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count)
return GetAgentMessagesResponse(messages=messages)
@router.post("/agents/message", tags=["agents"], response_model=UserMessageResponse)
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
async def send_message(
agent_id: uuid.UUID,
request: UserMessageRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
@@ -76,7 +75,7 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
This endpoint accepts a message from a user and processes it through the agent.
It can optionally stream the response if 'stream' is set to True.
"""
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
if request.role == "user" or request.role is None:
message_func = server.user_message

View File

@@ -827,7 +827,7 @@ class SyncServer(LockingServer):
messages = sorted(page, key=lambda x: x.created_at, reverse=True)
# convert to json
json_messages = [vars(record) for record in messages]
json_messages = [record.to_json() for record in messages]
return json_messages
def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list: