feat: expose cursor based retrieval of previous messages (#1140)

This commit is contained in:
Robin Goetz
2024-03-12 12:38:01 -04:00
committed by GitHub
parent 3f998148ce
commit 08124bdfa3
2 changed files with 26 additions and 3 deletions

View File

@@ -4,7 +4,7 @@ import uuid
from asyncio import AbstractEventLoop
from enum import Enum
from functools import partial
from typing import List
from typing import List, Optional
from fastapi import APIRouter, Body, HTTPException, Query, Depends
from pydantic import BaseModel, Field
@@ -38,6 +38,11 @@ class GetAgentMessagesRequest(BaseModel):
count: int = Field(..., description="How many messages to retrieve.")
class GetAgentMessagesCursorRequest(BaseModel):
    """Query parameters for cursor-based retrieval of an agent's messages."""

    # `before` may be left out; it defaults to None, matching the endpoint's
    # `Query(None, ...)` declaration for the same parameter.
    before: Optional[uuid.UUID] = Field(None, description="Message before which to retrieve the returned messages.")
    limit: int = Field(..., description="Maximum number of messages to retrieve.")
class GetAgentMessagesResponse(BaseModel):
messages: list = Field(..., description="List of message objects.")
@@ -63,6 +68,25 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count)
return GetAgentMessagesResponse(messages=messages)
@router.get("/agents/{agent_id}/messages-cursor", tags=["agents"], response_model=GetAgentMessagesResponse)
def get_agent_messages_cursor(
    agent_id: uuid.UUID,
    before: Optional[uuid.UUID] = Query(None, description="Message before which to retrieve the returned messages."),
    limit: int = Query(10, description="Maximum number of messages to retrieve."),
    user_id: uuid.UUID = Depends(get_current_user_with_server),
):
    """
    Retrieve messages of a specific agent using cursor-based pagination.

    Pass the ID of a message as `before` to fetch at most `limit` messages that
    precede it. When `before` is omitted, `None` is forwarded to
    `server.get_agent_recall_cursor`. To iterate, pass the ID of a message from
    the previous page as the next `before`.
    """
    # Validate with the Pydantic model (optional). `agent_id` is not a field of
    # GetAgentMessagesCursorRequest, so it is not passed to the model.
    request = GetAgentMessagesCursorRequest(before=before, limit=limit)

    interface.clear()
    # The server returns a (cursor, messages) pair; only the messages are
    # exposed in the response, so the cursor is discarded here.
    _, messages = server.get_agent_recall_cursor(
        user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True
    )
    return GetAgentMessagesResponse(messages=messages)
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
async def send_message(
agent_id: uuid.UUID,

View File

@@ -1,5 +1,4 @@
import json
from datetime import datetime
import logging
import uuid
from abc import abstractmethod
@@ -958,8 +957,8 @@ class SyncServer(LockingServer):
cursor, records = memgpt_agent.persistence_manager.recall_memory.storage.get_all_cursor(
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
)
json_records = [vars(record) for record in records]
json_records = [record.to_json() for record in records]
# TODO: mark what is in-context versus not
return cursor, json_records