fix: patch out-of-sync / missing tzinfo timestamps coming back from API server (#1182)
This commit is contained in:
@@ -36,3 +36,4 @@ memgpt_version = 0.3.7
|
||||
|
||||
[client]
|
||||
anon_clientid = 00000000-0000-0000-0000-000000000000
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from memgpt.utils import (
|
||||
validate_function_response,
|
||||
verify_first_message_correctness,
|
||||
create_uuid_from_string,
|
||||
is_utc_datetime,
|
||||
)
|
||||
from memgpt.constants import (
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
@@ -140,7 +141,7 @@ def initialize_message_sequence(
|
||||
recall_memory: Optional[RecallMemory] = None,
|
||||
memory_edit_timestamp: Optional[str] = None,
|
||||
include_initial_boot_message: bool = True,
|
||||
):
|
||||
) -> List[dict]:
|
||||
if memory_edit_timestamp is None:
|
||||
memory_edit_timestamp = get_local_time()
|
||||
|
||||
@@ -291,6 +292,13 @@ class Agent(object):
|
||||
assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, self.agent_state.state["messages"])
|
||||
self._messages.extend([cast(Message, msg) for msg in raw_messages if msg is not None])
|
||||
|
||||
for m in self._messages:
|
||||
# assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
|
||||
# TODO eventually do casting via an edit_message function
|
||||
if not is_utc_datetime(m.created_at):
|
||||
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
|
||||
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
|
||||
|
||||
else:
|
||||
# print(f"Agent.__init__ :: creating, state={agent_state.state['messages']}")
|
||||
init_messages = initialize_message_sequence(
|
||||
@@ -309,6 +317,13 @@ class Agent(object):
|
||||
self.messages_total = 0
|
||||
self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])
|
||||
|
||||
for m in self._messages:
|
||||
assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
|
||||
# TODO eventually do casting via an edit_message function
|
||||
if not is_utc_datetime(m.created_at):
|
||||
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
|
||||
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
|
||||
|
||||
# Keep track of the total number of messages throughout all time
|
||||
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
|
||||
# self.messages_total_init = self.messages_total
|
||||
@@ -445,6 +460,8 @@ class Agent(object):
|
||||
|
||||
# role: assistant (requesting tool call, set tool call ID)
|
||||
messages.append(
|
||||
# NOTE: we're recreating the message here
|
||||
# TODO should probably just overwrite the fields?
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
@@ -710,7 +727,7 @@ class Agent(object):
|
||||
# (if yes) Step 3: call the function
|
||||
# (if yes) Step 4: send the info on the function call and function response to LLM
|
||||
response_message = response.choices[0].message
|
||||
response_message.copy()
|
||||
response_message.model_copy() # TODO why are we copying here?
|
||||
all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message)
|
||||
|
||||
# Add the extra metadata to the assistant response
|
||||
|
||||
@@ -151,7 +151,7 @@ def get_db_model(
|
||||
metadata_ = Column(MutableJson)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
|
||||
def __repr__(self):
    """Return a debug string showing this passage row's id, text and embedding."""
    # The closing quote after the embedding value was missing, which left the
    # quoting in the rendered string unbalanced.
    return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding}')>"
|
||||
@@ -217,7 +217,7 @@ def get_db_model(
|
||||
embedding_model = Column(String)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
|
||||
def __repr__(self):
    """Return a debug string showing this message row's id, text and embedding."""
    # The closing quote after the embedding value was missing, which left the
    # quoting in the rendered string unbalanced.
    return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding}')>"
|
||||
|
||||
@@ -286,7 +286,7 @@ class RESTClient(AbstractClient):
|
||||
embedding_config=embedding_config,
|
||||
state=response.agent_state.state,
|
||||
# load datetime from timestamp
|
||||
created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at),
|
||||
created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at, tz=datetime.timezone.utc),
|
||||
)
|
||||
return agent_state
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List, Dict, TypeVar
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, Field, Json
|
||||
|
||||
from memgpt.constants import (
|
||||
DEFAULT_HUMAN,
|
||||
@@ -14,14 +15,9 @@ from memgpt.constants import (
|
||||
MAX_EMBEDDING_DIM,
|
||||
TOOL_CALL_ID_MAX_LEN,
|
||||
)
|
||||
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
|
||||
from memgpt.utils import get_utc_time, create_uuid_from_string
|
||||
from memgpt.models import chat_completion_response
|
||||
from memgpt.utils import get_human_text, get_persona_text, printd
|
||||
|
||||
from pydantic import BaseModel, Field, Json
|
||||
from memgpt.utils import get_human_text, get_persona_text, printd
|
||||
|
||||
from pydantic import BaseModel, Field, Json
|
||||
from memgpt.utils import get_human_text, get_persona_text, printd, is_utc_datetime
|
||||
|
||||
|
||||
class Record:
|
||||
@@ -136,6 +132,11 @@ class Message(Record):
|
||||
json_message = vars(self)
|
||||
if json_message["tool_calls"] is not None:
|
||||
json_message["tool_calls"] = [vars(tc) for tc in json_message["tool_calls"]]
|
||||
# turn datetime to ISO format
|
||||
# also if the created_at is missing a timezone, add UTC
|
||||
if not is_utc_datetime(self.created_at):
|
||||
self.created_at = self.created_at.replace(tzinfo=timezone.utc)
|
||||
json_message["created_at"] = self.created_at.isoformat()
|
||||
return json_message
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -22,7 +22,7 @@ def send_message(self: Agent, message: str) -> Optional[str]:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
# FIXME passing of msg_obj here is a hack, unclear if guaranteed to be the correct reference
|
||||
self.interface.assistant_message(message, msg_obj=self._messages[-1])
|
||||
self.interface.assistant_message(message) # , msg_obj=self._messages[-1])
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import secrets
|
||||
from typing import Optional, List
|
||||
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
|
||||
from memgpt.utils import get_local_time, enforce_types
|
||||
from memgpt.utils import enforce_types
|
||||
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.functions.functions import load_all_function_sets
|
||||
@@ -549,7 +549,7 @@ class MetadataStore:
|
||||
)
|
||||
for k, v in available_functions.items()
|
||||
]
|
||||
print(results)
|
||||
# print(results)
|
||||
return results
|
||||
# results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
|
||||
# return [r.to_record() for r in results]
|
||||
|
||||
@@ -5,7 +5,7 @@ from memgpt.memory import (
|
||||
BaseRecallMemory,
|
||||
EmbeddingArchivalMemory,
|
||||
)
|
||||
from memgpt.utils import get_local_time, printd
|
||||
from memgpt.utils import printd
|
||||
from memgpt.data_types import Message, ToolCall, AgentState
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from asyncio import AbstractEventLoop
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from fastapi import APIRouter, Body, HTTPException, Query, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from memgpt.constants import JSON_ENSURE_ASCII
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.data_types import Message
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -33,6 +34,14 @@ class UserMessageRequest(BaseModel):
|
||||
description="Timestamp to tag the message with (in ISO format). If null, timestamp will be created server-side on receipt of message.",
|
||||
)
|
||||
|
||||
@validator("timestamp")
def validate_timestamp(cls, value: Any) -> Any:
    """Require the request timestamp to be timezone-aware and expressed in UTC.

    Args:
        value: The ``timestamp`` field value (expected to be a datetime).
            ``None`` is passed through unchanged, since the field description
            allows a null timestamp.

    Returns:
        The unchanged value when it is ``None`` or a UTC-aware datetime.

    Raises:
        ValueError: If the datetime carries no usable timezone information,
            or if its UTC offset is not zero.
    """
    if value is None:
        return value

    # datetime.utcoffset() is None both when tzinfo is missing and when the
    # tzinfo cannot supply an offset, covering both "naive" cases at once.
    offset = value.utcoffset()
    if offset is None:
        raise ValueError("Timestamp must include timezone information.")

    # Compare against UTC's zero offset. The earlier expression passed
    # timezone.utc to datetime.fromtimestamp(), which expects a numeric POSIX
    # timestamp and therefore raised TypeError instead of validating.
    if offset != timezone.utc.utcoffset(None):
        raise ValueError("Timestamp must be in UTC.")
    return value
|
||||
|
||||
|
||||
class UserMessageResponse(BaseModel):
|
||||
messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
|
||||
@@ -90,6 +99,12 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
|
||||
[_, messages] = server.get_agent_recall_cursor(
|
||||
user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True
|
||||
)
|
||||
# print("====> messages-cursor DEBUG")
|
||||
# for i, msg in enumerate(messages):
|
||||
# print(f"message {i+1}/{len(messages)}")
|
||||
# print(f"UTC created-at: {msg.created_at.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'}")
|
||||
# print(f"ISO format string: {msg['created_at']}")
|
||||
# print(msg)
|
||||
return GetAgentMessagesResponse(messages=messages)
|
||||
|
||||
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytz
|
||||
|
||||
from memgpt.interface import AgentInterface
|
||||
from memgpt.data_types import Message
|
||||
from memgpt.utils import is_utc_datetime
|
||||
|
||||
|
||||
class QueuingInterface(AgentInterface):
|
||||
@@ -57,34 +58,54 @@ class QueuingInterface(AgentInterface):
|
||||
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
    """Receive a user message; only emits debug output when debugging is on.

    A message object is mandatory here (asserted). With ``self.debug`` set,
    the raw text, the object's attributes and its ISO-formatted creation time
    are printed.
    """
    assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
    if not self.debug:
        return
    print(msg)
    print(vars(msg_obj))
    print(msg_obj.created_at.isoformat())
|
||||
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
    """Put the agent's inner monologue on the buffer as an ``internal_monologue`` entry.

    A message object is mandatory (asserted). Its id and ISO-formatted
    creation time are attached to the queued entry; the creation time is
    asserted to be UTC before it is serialized.
    """
    assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
    if self.debug:
        print(msg)
        print(vars(msg_obj))
        print(msg_obj.created_at.isoformat())

    payload = {"internal_monologue": msg}

    # attach identifying metadata taken from the message object
    if msg_obj is not None:
        payload["id"] = str(msg_obj.id)
        assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
        payload["date"] = msg_obj.created_at.isoformat()

    self.buffer.put(payload)
|
||||
|
||||
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
|
||||
"""Handle the agent sending a message"""
|
||||
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
|
||||
# assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
|
||||
|
||||
if self.debug:
|
||||
print(msg)
|
||||
if msg_obj is not None:
|
||||
print(vars(msg_obj))
|
||||
print(msg_obj.created_at.isoformat())
|
||||
|
||||
new_message = {"assistant_message": msg}
|
||||
|
||||
# add extra metadata
|
||||
if msg_obj is not None:
|
||||
new_message["id"] = str(msg_obj.id)
|
||||
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
|
||||
new_message["date"] = msg_obj.created_at.isoformat()
|
||||
else:
|
||||
# FIXME this is a total hack
|
||||
assert self.buffer.qsize() > 1, "Tried to reach back to grab function call data, but couldn't find a buffer message."
|
||||
# TODO also should not be accessing protected member here
|
||||
|
||||
new_message["id"] = self.buffer.queue[-1]["id"]
|
||||
# assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
|
||||
new_message["date"] = self.buffer.queue[-1]["date"]
|
||||
|
||||
self.buffer.put(new_message)
|
||||
|
||||
@@ -95,6 +116,8 @@ class QueuingInterface(AgentInterface):
|
||||
|
||||
if self.debug:
|
||||
print(msg)
|
||||
print(vars(msg_obj))
|
||||
print(msg_obj.created_at.isoformat())
|
||||
|
||||
if msg.startswith("Running "):
|
||||
msg = msg.replace("Running ", "")
|
||||
@@ -121,6 +144,7 @@ class QueuingInterface(AgentInterface):
|
||||
# add extra metadata
|
||||
if msg_obj is not None:
|
||||
new_message["id"] = str(msg_obj.id)
|
||||
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
|
||||
new_message["date"] = msg_obj.created_at.isoformat()
|
||||
|
||||
self.buffer.put(new_message)
|
||||
|
||||
@@ -1064,7 +1064,7 @@ class SyncServer(LockingServer):
|
||||
order_by: Optional[str] = "created_at",
|
||||
order: Optional[str] = "asc",
|
||||
reverse: Optional[bool] = False,
|
||||
):
|
||||
) -> Tuple[uuid.UUID, List[dict]]:
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import copy
|
||||
import re
|
||||
import json
|
||||
@@ -469,6 +469,10 @@ NOUN_BANK = [
|
||||
]
|
||||
|
||||
|
||||
def is_utc_datetime(dt: datetime) -> bool:
    """Return True when ``dt`` is timezone-aware with a zero UTC offset.

    Naive datetimes (no tzinfo, or a tzinfo that reports no offset) yield
    False, as do aware datetimes in any non-UTC offset.
    """
    # datetime.utcoffset() returns None for naive values, and None never
    # compares equal to timedelta(0), so no separate tzinfo check is needed.
    return dt.utcoffset() == timedelta(0)
|
||||
|
||||
|
||||
def get_tool_call_id() -> str:
|
||||
return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user