fix: patch out-of-sync / missing tzinfo timestamps coming back from API server (#1182)

This commit is contained in:
Charles Packer
2024-03-26 20:37:44 -07:00
committed by GitHub
parent c0bd66c957
commit 4b5666ac64
12 changed files with 85 additions and 23 deletions

View File

@@ -36,3 +36,4 @@ memgpt_version = 0.3.7
[client]
anon_clientid = 00000000-0000-0000-0000-000000000000

View File

@@ -29,6 +29,7 @@ from memgpt.utils import (
validate_function_response,
verify_first_message_correctness,
create_uuid_from_string,
is_utc_datetime,
)
from memgpt.constants import (
FIRST_MESSAGE_ATTEMPTS,
@@ -140,7 +141,7 @@ def initialize_message_sequence(
recall_memory: Optional[RecallMemory] = None,
memory_edit_timestamp: Optional[str] = None,
include_initial_boot_message: bool = True,
):
) -> List[dict]:
if memory_edit_timestamp is None:
memory_edit_timestamp = get_local_time()
@@ -291,6 +292,13 @@ class Agent(object):
assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, self.agent_state.state["messages"])
self._messages.extend([cast(Message, msg) for msg in raw_messages if msg is not None])
for m in self._messages:
# assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
# TODO eventually do casting via an edit_message function
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
else:
# print(f"Agent.__init__ :: creating, state={agent_state.state['messages']}")
init_messages = initialize_message_sequence(
@@ -309,6 +317,13 @@ class Agent(object):
self.messages_total = 0
self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])
for m in self._messages:
assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
# TODO eventually do casting via an edit_message function
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
# Keep track of the total number of messages throughout all time
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
# self.messages_total_init = self.messages_total
@@ -445,6 +460,8 @@ class Agent(object):
# role: assistant (requesting tool call, set tool call ID)
messages.append(
# NOTE: we're recreating the message here
# TODO should probably just overwrite the fields?
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
@@ -710,7 +727,7 @@ class Agent(object):
# (if yes) Step 3: call the function
# (if yes) Step 4: send the info on the function call and function response to LLM
response_message = response.choices[0].message
response_message.copy()
response_message.model_copy() # TODO why are we copying here?
all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message)
# Add the extra metadata to the assistant response

View File

@@ -151,7 +151,7 @@ def get_db_model(
metadata_ = Column(MutableJson)
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True))
def __repr__(self):
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
@@ -217,7 +217,7 @@ def get_db_model(
embedding_model = Column(String)
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True), server_default=func.now())
created_at = Column(DateTime(timezone=True))
def __repr__(self):
return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

View File

@@ -286,7 +286,7 @@ class RESTClient(AbstractClient):
embedding_config=embedding_config,
state=response.agent_state.state,
# load datetime from timestampe
created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at),
created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at, tz=datetime.timezone.utc),
)
return agent_state

View File

@@ -1,9 +1,10 @@
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """
import uuid
from datetime import datetime
from datetime import datetime, timezone
from typing import Optional, List, Dict, TypeVar
import numpy as np
from pydantic import BaseModel, Field, Json
from memgpt.constants import (
DEFAULT_HUMAN,
@@ -14,14 +15,9 @@ from memgpt.constants import (
MAX_EMBEDDING_DIM,
TOOL_CALL_ID_MAX_LEN,
)
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
from memgpt.utils import get_utc_time, create_uuid_from_string
from memgpt.models import chat_completion_response
from memgpt.utils import get_human_text, get_persona_text, printd
from pydantic import BaseModel, Field, Json
from memgpt.utils import get_human_text, get_persona_text, printd
from pydantic import BaseModel, Field, Json
from memgpt.utils import get_human_text, get_persona_text, printd, is_utc_datetime
class Record:
@@ -136,6 +132,11 @@ class Message(Record):
json_message = vars(self)
if json_message["tool_calls"] is not None:
json_message["tool_calls"] = [vars(tc) for tc in json_message["tool_calls"]]
# turn datetime to ISO format
# also if the created_at is missing a timezone, add UTC
if not is_utc_datetime(self.created_at):
self.created_at = self.created_at.replace(tzinfo=timezone.utc)
json_message["created_at"] = self.created_at.isoformat()
return json_message
@staticmethod

View File

@@ -22,7 +22,7 @@ def send_message(self: Agent, message: str) -> Optional[str]:
Optional[str]: None is always returned as this function does not produce a response.
"""
# FIXME passing of msg_obj here is a hack, unclear if guaranteed to be the correct reference
self.interface.assistant_message(message, msg_obj=self._messages[-1])
self.interface.assistant_message(message) # , msg_obj=self._messages[-1])
return None

View File

@@ -7,7 +7,7 @@ import secrets
from typing import Optional, List
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
from memgpt.utils import get_local_time, enforce_types
from memgpt.utils import enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
from memgpt.config import MemGPTConfig
from memgpt.functions.functions import load_all_function_sets
@@ -549,7 +549,7 @@ class MetadataStore:
)
for k, v in available_functions.items()
]
print(results)
# print(results)
return results
# results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
# return [r.to_record() for r in results]

View File

@@ -5,7 +5,7 @@ from memgpt.memory import (
BaseRecallMemory,
EmbeddingArchivalMemory,
)
from memgpt.utils import get_local_time, printd
from memgpt.utils import printd
from memgpt.data_types import Message, ToolCall, AgentState
from datetime import datetime

View File

@@ -1,20 +1,21 @@
import asyncio
import json
import uuid
from datetime import datetime
from datetime import datetime, timezone
from asyncio import AbstractEventLoop
from enum import Enum
from functools import partial
from typing import List, Optional
from typing import List, Optional, Any
from fastapi import APIRouter, Body, HTTPException, Query, Depends
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
from starlette.responses import StreamingResponse
from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.data_types import Message
router = APIRouter()
@@ -33,6 +34,14 @@ class UserMessageRequest(BaseModel):
description="Timestamp to tag the message with (in ISO format). If null, timestamp will be created server-side on receipt of message.",
)
@validator("timestamp")
def validate_timestamp(cls, value: Any) -> Any:
if value.tzinfo is None or value.tzinfo.utcoffset(value) is None:
raise ValueError("Timestamp must include timezone information.")
if value.tzinfo.utcoffset(value) != datetime.fromtimestamp(timezone.utc).utcoffset():
raise ValueError("Timestamp must be in UTC.")
return value
class UserMessageResponse(BaseModel):
messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
@@ -90,6 +99,12 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
[_, messages] = server.get_agent_recall_cursor(
user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True
)
# print("====> messages-cursor DEBUG")
# for i, msg in enumerate(messages):
# print(f"message {i+1}/{len(messages)}")
# print(f"UTC created-at: {msg.created_at.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'}")
# print(f"ISO format string: {msg['created_at']}")
# print(msg)
return GetAgentMessagesResponse(messages=messages)
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)

View File

@@ -7,6 +7,7 @@ import pytz
from memgpt.interface import AgentInterface
from memgpt.data_types import Message
from memgpt.utils import is_utc_datetime
class QueuingInterface(AgentInterface):
@@ -57,34 +58,54 @@ class QueuingInterface(AgentInterface):
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
"""Handle reception of a user message"""
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
if self.debug:
print(msg)
print(vars(msg_obj))
print(msg_obj.created_at.isoformat())
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""Handle the agent's internal monologue"""
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
if self.debug:
print(msg)
print(vars(msg_obj))
print(msg_obj.created_at.isoformat())
new_message = {"internal_monologue": msg}
# add extra metadata
if msg_obj is not None:
new_message["id"] = str(msg_obj.id)
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat()
self.buffer.put(new_message)
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""Handle the agent sending a message"""
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
# assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
if self.debug:
print(msg)
if msg_obj is not None:
print(vars(msg_obj))
print(msg_obj.created_at.isoformat())
new_message = {"assistant_message": msg}
# add extra metadata
if msg_obj is not None:
new_message["id"] = str(msg_obj.id)
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat()
else:
# FIXME this is a total hack
assert self.buffer.qsize() > 1, "Tried to reach back to grab function call data, but couldn't find a buffer message."
# TODO also should not be accessing protected member here
new_message["id"] = self.buffer.queue[-1]["id"]
# assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = self.buffer.queue[-1]["date"]
self.buffer.put(new_message)
@@ -95,6 +116,8 @@ class QueuingInterface(AgentInterface):
if self.debug:
print(msg)
print(vars(msg_obj))
print(msg_obj.created_at.isoformat())
if msg.startswith("Running "):
msg = msg.replace("Running ", "")
@@ -121,6 +144,7 @@ class QueuingInterface(AgentInterface):
# add extra metadata
if msg_obj is not None:
new_message["id"] = str(msg_obj.id)
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat()
self.buffer.put(new_message)

View File

@@ -1064,7 +1064,7 @@ class SyncServer(LockingServer):
order_by: Optional[str] = "created_at",
order: Optional[str] = "asc",
reverse: Optional[bool] = False,
):
) -> Tuple[uuid.UUID, List[dict]]:
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:

View File

@@ -1,4 +1,4 @@
from datetime import datetime, timezone
from datetime import datetime, timezone, timedelta
import copy
import re
import json
@@ -469,6 +469,10 @@ NOUN_BANK = [
]
def is_utc_datetime(dt: datetime) -> bool:
return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) == timedelta(0)
def get_tool_call_id() -> str:
return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]