feat: Factor out custom_columns serialize/deserialize logic (#1028)

This commit is contained in:
Matthew Zhou
2025-02-17 15:07:12 -08:00
committed by GitHub
parent a1cc16dd5a
commit 3dc1767f46
2 changed files with 190 additions and 117 deletions

152
letta/helpers/converters.py Normal file
View File

@@ -0,0 +1,152 @@
import base64
from typing import Any, Dict, List, Optional, Union
import numpy as np
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from sqlalchemy import Dialect
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule, ToolRule
# --------------------------
# LLMConfig Serialization
# --------------------------
def serialize_llm_config(config: Union[Optional[LLMConfig], Dict]) -> Optional[Dict]:
    """Convert an LLMConfig object into a JSON-serializable dictionary.

    Falsy inputs (None, empty dict) and plain dicts pass through unchanged;
    only actual LLMConfig instances are dumped to a dict.
    """
    if not config:
        return config
    return config.model_dump() if isinstance(config, LLMConfig) else config
def deserialize_llm_config(data: Optional[Dict]) -> Optional[LLMConfig]:
    """Convert a dictionary back into an LLMConfig object.

    Returns None for falsy input (None or an empty dict).
    """
    if not data:
        return None
    return LLMConfig(**data)
# --------------------------
# EmbeddingConfig Serialization
# --------------------------
def serialize_embedding_config(config: Union[Optional[EmbeddingConfig], Dict]) -> Optional[Dict]:
    """Convert an EmbeddingConfig object into a JSON-serializable dictionary.

    Falsy inputs and plain dicts pass through unchanged; only actual
    EmbeddingConfig instances are dumped to a dict.
    """
    if not config:
        return config
    return config.model_dump() if isinstance(config, EmbeddingConfig) else config
def deserialize_embedding_config(data: Optional[Dict]) -> Optional[EmbeddingConfig]:
    """Convert a dictionary back into an EmbeddingConfig object.

    Returns None for falsy input (None or an empty dict).
    """
    if not data:
        return None
    return EmbeddingConfig(**data)
# --------------------------
# ToolRule Serialization
# --------------------------
def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str, Any]]:
    """Convert a list of ToolRules into a JSON-serializable format.

    Each rule is dumped to a dict with its enum `type` replaced by the plain
    string value so the result round-trips through JSON.
    """
    if not tool_rules:
        return []

    data = []
    for rule in tool_rules:
        # Replace the Enum with its string value for JSON compatibility.
        data.append({**rule.model_dump(), "type": rule.type.value})

    # Validate structure: child-constraint rules must carry a 'children' list.
    for rule_data in data:
        if rule_data["type"] == ToolRuleType.constrain_child_tools.value and "children" not in rule_data:
            raise ValueError(f"Invalid ToolRule serialization: 'children' field missing for rule {rule_data}")

    return data
def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]:
    """Convert a list of dictionaries back into ToolRule objects.

    Delegates per-item reconstruction to `deserialize_tool_rule`; falsy input
    yields an empty list.
    """
    if not data:
        return []
    rules = []
    for rule_data in data:
        rules.append(deserialize_tool_rule(rule_data))
    return rules
def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
    """Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'.

    Raises ValueError for an enum value with no registered subclass (the enum
    constructor itself rejects unknown raw strings before we get here).
    """
    rule_type = ToolRuleType(data.get("type"))
    # Dispatch table: rule-type tag -> concrete ToolRule subclass.
    rule_class_by_type = {
        ToolRuleType.run_first: InitToolRule,
        ToolRuleType.exit_loop: TerminalToolRule,
        ToolRuleType.constrain_child_tools: ChildToolRule,
        ToolRuleType.conditional: ConditionalToolRule,
    }
    rule_class = rule_class_by_type.get(rule_type)
    if rule_class is None:
        raise ValueError(f"Unknown ToolRule type: {rule_type}")
    return rule_class(**data)
# --------------------------
# ToolCall Serialization
# --------------------------
def serialize_tool_calls(tool_calls: Optional[List[Union[OpenAIToolCall, dict]]]) -> List[Dict]:
    """Convert a list of OpenAI ToolCall objects into JSON-serializable format.

    OpenAIToolCall instances are dumped to dicts; plain dicts pass through
    untouched. Anything else raises TypeError.
    """
    if not tool_calls:
        return []

    serialized_calls = []
    for call in tool_calls:
        if isinstance(call, OpenAIToolCall):
            serialized_calls.append(call.model_dump())
        elif isinstance(call, dict):
            # Already a dictionary, leave it as-is.
            serialized_calls.append(call)
        else:
            raise TypeError(f"Unexpected tool call type: {type(call)}")

    return serialized_calls
def deserialize_tool_calls(data: Optional[List[Dict]]) -> List[OpenAIToolCall]:
    """Convert a JSON list back into OpenAIToolCall objects.

    Args:
        data: List of serialized tool-call dicts (each optionally carrying a
            nested "function" dict), or None.

    Returns:
        Reconstructed OpenAIToolCall objects; an empty list for falsy input.
    """
    if not data:
        return []
    calls = []
    for item in data:
        # Work on a shallow copy: the previous implementation popped
        # "function" from the caller's dict in place, silently mutating the
        # input and making a second deserialization of the same data fail.
        item = dict(item)
        func_data = item.pop("function", None)
        tool_call_function = OpenAIFunction(**func_data) if func_data else None
        calls.append(OpenAIToolCall(function=tool_call_function, **item))
    return calls
# --------------------------
# Vector Serialization
# --------------------------
def serialize_vector(vector: Optional[Union[List[float], np.ndarray]]) -> Optional[bytes]:
    """Convert a NumPy array or list into a base64-encoded byte string.

    Lists are first converted to float32 arrays; ndarrays are encoded as-is
    (their existing dtype's raw bytes are used).
    """
    if vector is None:
        return None
    arr = np.array(vector, dtype=np.float32) if isinstance(vector, list) else vector
    return base64.b64encode(arr.tobytes())
def deserialize_vector(data: Optional[bytes], dialect: Dialect) -> Optional[np.ndarray]:
    """Convert a base64-encoded byte string back into a NumPy float32 array.

    SQLite stores the base64 text itself, so it is decoded first; other
    dialects are assumed to hand back the raw float32 bytes directly.
    """
    if not data:
        return None
    raw = base64.b64decode(data) if dialect.name == "sqlite" else data
    return np.frombuffer(raw, dtype=np.float32)

View File

@@ -1,159 +1,80 @@
import base64
from typing import List, Union
import numpy as np
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from sqlalchemy import JSON
from sqlalchemy.types import BINARY, TypeDecorator
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
class EmbeddingConfigColumn(TypeDecorator):
    """Custom type for storing EmbeddingConfig as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Persist as a generic JSON column regardless of backend.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Outbound: dump the EmbeddingConfig model to a plain dict for storage.
        if value and isinstance(value, EmbeddingConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        # Inbound: rebuild the EmbeddingConfig model from the stored dict.
        return EmbeddingConfig(**value) if value else value
from letta.helpers.converters import (
deserialize_embedding_config,
deserialize_llm_config,
deserialize_tool_calls,
deserialize_tool_rules,
deserialize_vector,
serialize_embedding_config,
serialize_llm_config,
serialize_tool_calls,
serialize_tool_rules,
serialize_vector,
)
class LLMConfigColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing LLMConfig as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Persist as a generic JSON column regardless of backend.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Outbound: delegate serialization to the shared converter helper.
        return serialize_llm_config(value)

    def process_result_value(self, value, dialect):
        # Inbound: delegate deserialization to the shared converter helper.
        return deserialize_llm_config(value)
class EmbeddingConfigColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing EmbeddingConfig as JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Outbound: serialize the EmbeddingConfig to a JSON-compatible dict.
        return serialize_embedding_config(value)

    def process_result_value(self, value, dialect):
        # Inbound: rebuild the EmbeddingConfig from its stored dict.
        return deserialize_embedding_config(value)
class ToolRulesColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing a list of ToolRules as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Persist as a generic JSON column regardless of backend.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Convert a list of ToolRules to JSON-serializable format."""
        return serialize_tool_rules(value)

    def process_result_value(self, value, dialect):
        """Convert JSON back to a list of ToolRules."""
        return deserialize_tool_rules(value)
class ToolCallColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing OpenAI ToolCall objects as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Persist as a generic JSON column regardless of backend.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Outbound: delegate serialization to the shared converter helper.
        return serialize_tool_calls(value)

    def process_result_value(self, value, dialect):
        # Inbound: delegate deserialization to the shared converter helper.
        return deserialize_tool_calls(value)
class CommonVector(TypeDecorator):
    """Custom SQLAlchemy column type for storing vectors in SQLite."""

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Vectors are stored as raw binary blobs.
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        # Outbound: base64-encode the list/ndarray via the shared helper.
        return serialize_vector(value)

    def process_result_value(self, value, dialect):
        # Inbound: decode back to a float32 NumPy array; the dialect is
        # forwarded because SQLite stores the base64 text itself.
        return deserialize_vector(value, dialect)