diff --git a/memgpt/server/rest_api/agents/message.py b/memgpt/server/rest_api/agents/message.py index 8b1f3a15..7b514eda 100644 --- a/memgpt/server/rest_api/agents/message.py +++ b/memgpt/server/rest_api/agents/message.py @@ -4,7 +4,7 @@ import uuid from asyncio import AbstractEventLoop from enum import Enum from functools import partial -from typing import List +from typing import List, Optional from fastapi import APIRouter, Body, HTTPException, Query, Depends from pydantic import BaseModel, Field @@ -38,6 +38,11 @@ class GetAgentMessagesRequest(BaseModel): count: int = Field(..., description="How many messages to retrieve.") +class GetAgentMessagesCursorRequest(BaseModel): + before: Optional[uuid.UUID] = Field(..., description="Message before which to retrieve the returned messages.") + limit: int = Field(..., description="Maximum number of messages to retrieve.") + + class GetAgentMessagesResponse(BaseModel): messages: list = Field(..., description="List of message objects.") @@ -63,6 +68,25 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count) return GetAgentMessagesResponse(messages=messages) + @router.get("/agents/{agent_id}/messages-cursor", tags=["agents"], response_model=GetAgentMessagesResponse) + def get_agent_messages_cursor( + agent_id: uuid.UUID, + before: Optional[uuid.UUID] = Query(None, description="Message before which to retrieve the returned messages."), + limit: int = Query(10, description="Maximum number of messages to retrieve."), + user_id: uuid.UUID = Depends(get_current_user_with_server), + ): + """ + Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate. + """ + # Validate with the Pydantic model (optional) + request = GetAgentMessagesCursorRequest(agent_id=agent_id, before=before, limit=limit) + + interface.clear() + [_, messages] = server.get_agent_recall_cursor( + user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True + ) + return GetAgentMessagesResponse(messages=messages) + @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse) async def send_message( agent_id: uuid.UUID, diff --git a/memgpt/server/server.py b/memgpt/server/server.py index be5bfe44..c9885560 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1,5 +1,4 @@ import json -from datetime import datetime import logging import uuid from abc import abstractmethod @@ -958,8 +957,8 @@ class SyncServer(LockingServer): cursor, records = memgpt_agent.persistence_manager.recall_memory.storage.get_all_cursor( after=after, before=before, limit=limit, order_by=order_by, reverse=reverse ) - json_records = [vars(record) for record in records] + json_records = [record.to_json() for record in records] # TODO: mark what is in-context versus not return cursor, json_records