feat: expose cursor based retrieval of previous messages (#1140)
This commit is contained in:
@@ -4,7 +4,7 @@ import uuid
|
||||
from asyncio import AbstractEventLoop
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -38,6 +38,11 @@ class GetAgentMessagesRequest(BaseModel):
|
||||
count: int = Field(..., description="How many messages to retrieve.")
|
||||
|
||||
|
||||
class GetAgentMessagesCursorRequest(BaseModel):
|
||||
before: Optional[uuid.UUID] = Field(..., description="Message before which to retrieve the returned messages.")
|
||||
limit: int = Field(..., description="Maximum number of messages to retrieve.")
|
||||
|
||||
|
||||
class GetAgentMessagesResponse(BaseModel):
|
||||
messages: list = Field(..., description="List of message objects.")
|
||||
|
||||
@@ -63,6 +68,25 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
|
||||
messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count)
|
||||
return GetAgentMessagesResponse(messages=messages)
|
||||
|
||||
@router.get("/agents/{agent_id}/messages-cursor", tags=["agents"], response_model=GetAgentMessagesResponse)
|
||||
def get_agent_messages_cursor(
|
||||
agent_id: uuid.UUID,
|
||||
before: Optional[uuid.UUID] = Query(None, description="Message before which to retrieve the returned messages."),
|
||||
limit: int = Query(10, description="Maximum number of messages to retrieve."),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
"""
|
||||
Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
|
||||
"""
|
||||
# Validate with the Pydantic model (optional)
|
||||
request = GetAgentMessagesCursorRequest(agent_id=agent_id, before=before, limit=limit)
|
||||
|
||||
interface.clear()
|
||||
[_, messages] = server.get_agent_recall_cursor(
|
||||
user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True
|
||||
)
|
||||
return GetAgentMessagesResponse(messages=messages)
|
||||
|
||||
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
|
||||
async def send_message(
|
||||
agent_id: uuid.UUID,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
@@ -958,8 +957,8 @@ class SyncServer(LockingServer):
|
||||
cursor, records = memgpt_agent.persistence_manager.recall_memory.storage.get_all_cursor(
|
||||
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
|
||||
)
|
||||
json_records = [vars(record) for record in records]
|
||||
|
||||
json_records = [record.to_json() for record in records]
|
||||
# TODO: mark what is in-context versus not
|
||||
return cursor, json_records
|
||||
|
||||
|
||||
Reference in New Issue
Block a user