feat: Factor out custom_columns serialize/deserialize logic (#1028)
This commit is contained in:
152
letta/helpers/converters.py
Normal file
152
letta/helpers/converters.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from sqlalchemy import Dialect
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule, ToolRule
|
||||
|
||||
# --------------------------
|
||||
# LLMConfig Serialization
|
||||
# --------------------------
|
||||
|
||||
|
||||
def serialize_llm_config(config: Union[Optional[LLMConfig], Dict]) -> Optional[Dict]:
    """Convert an LLMConfig object into a JSON-serializable dictionary.

    Plain dicts and falsy values (None, {}) pass through unchanged.
    """
    # NOTE: truthiness is checked first so falsy inputs never touch isinstance.
    is_model = bool(config) and isinstance(config, LLMConfig)
    return config.model_dump() if is_model else config
|
||||
|
||||
|
||||
def deserialize_llm_config(data: Optional[Dict]) -> Optional[LLMConfig]:
    """Convert a dictionary back into an LLMConfig object.

    Falsy input (None or an empty dict) yields None.
    """
    if not data:
        return None
    return LLMConfig(**data)
|
||||
|
||||
|
||||
# --------------------------
|
||||
# EmbeddingConfig Serialization
|
||||
# --------------------------
|
||||
|
||||
|
||||
def serialize_embedding_config(config: Union[Optional[EmbeddingConfig], Dict]) -> Optional[Dict]:
    """Convert an EmbeddingConfig object into a JSON-serializable dictionary.

    Plain dicts and falsy values (None, {}) pass through unchanged.
    """
    # NOTE: truthiness is checked first so falsy inputs never touch isinstance.
    is_model = bool(config) and isinstance(config, EmbeddingConfig)
    return config.model_dump() if is_model else config
|
||||
|
||||
|
||||
def deserialize_embedding_config(data: Optional[Dict]) -> Optional[EmbeddingConfig]:
    """Convert a dictionary back into an EmbeddingConfig object.

    Falsy input (None or an empty dict) yields None.
    """
    if not data:
        return None
    return EmbeddingConfig(**data)
|
||||
|
||||
|
||||
# --------------------------
|
||||
# ToolRule Serialization
|
||||
# --------------------------
|
||||
|
||||
|
||||
def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str, Any]]:
    """Convert a list of ToolRules into a JSON-serializable format.

    Each rule is dumped to a dict and its ``type`` enum is flattened to its
    string value for JSON compatibility. Raises ValueError if a
    constrain_child_tools rule is missing its required 'children' field.
    """
    if not tool_rules:
        return []

    serialized = []
    for rule in tool_rules:
        entry = {**rule.model_dump(), "type": rule.type.value}  # Enum -> plain string
        # Guard: child-constraint rules must carry their 'children' list.
        if entry["type"] == ToolRuleType.constrain_child_tools.value and "children" not in entry:
            raise ValueError(f"Invalid ToolRule serialization: 'children' field missing for rule {entry}")
        serialized.append(entry)

    return serialized
|
||||
|
||||
|
||||
def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]:
    """Convert a list of dictionaries back into ToolRule objects.

    Falsy input (None or an empty list) yields an empty list; each entry is
    dispatched to the matching ToolRule subclass via deserialize_tool_rule.
    """
    return [deserialize_tool_rule(entry) for entry in data] if data else []
|
||||
|
||||
|
||||
def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
    """Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'."""
    rule_type = ToolRuleType(data.get("type"))

    # Dispatch table instead of an if/elif chain; unknown members fall through
    # to the explicit ValueError below.
    rule_classes = {
        ToolRuleType.run_first: InitToolRule,
        ToolRuleType.exit_loop: TerminalToolRule,
        ToolRuleType.constrain_child_tools: ChildToolRule,
        ToolRuleType.conditional: ConditionalToolRule,
    }
    rule_class = rule_classes.get(rule_type)
    if rule_class is None:
        raise ValueError(f"Unknown ToolRule type: {rule_type}")
    return rule_class(**data)
|
||||
|
||||
|
||||
# --------------------------
|
||||
# ToolCall Serialization
|
||||
# --------------------------
|
||||
|
||||
|
||||
def serialize_tool_calls(tool_calls: Optional[List[Union[OpenAIToolCall, dict]]]) -> List[Dict]:
    """Convert a list of OpenAI ToolCall objects into JSON-serializable format.

    Model instances are dumped to dicts; dicts are passed through untouched.
    Raises TypeError for anything else.
    """
    if not tool_calls:
        return []

    serialized = []
    for call in tool_calls:
        if isinstance(call, OpenAIToolCall):
            serialized.append(call.model_dump())
        elif isinstance(call, dict):
            # Already JSON-shaped; keep as-is.
            serialized.append(call)
        else:
            raise TypeError(f"Unexpected tool call type: {type(call)}")

    return serialized
|
||||
|
||||
|
||||
def deserialize_tool_calls(data: Optional[List[Dict]]) -> List[OpenAIToolCall]:
    """Convert a JSON list back into OpenAIToolCall objects.

    Falsy input yields an empty list. The input dicts are NOT mutated: the
    original implementation popped "function" out of the caller's dicts,
    silently corrupting data that may be reused (e.g. cached DB rows).
    """
    if not data:
        return []

    calls = []
    for item in data:
        item = dict(item)  # shallow copy so the caller's dict keeps its "function" key
        func_data = item.pop("function", None)
        tool_call_function = OpenAIFunction(**func_data) if func_data else None
        calls.append(OpenAIToolCall(function=tool_call_function, **item))

    return calls
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Vector Serialization
|
||||
# --------------------------
|
||||
|
||||
|
||||
def serialize_vector(vector: Optional[Union[List[float], np.ndarray]]) -> Optional[bytes]:
    """Convert a NumPy array or list into a base64-encoded byte string.

    Lists are first converted to float32 arrays; existing arrays are encoded
    with their current dtype. None passes through as None.
    """
    if vector is None:
        return None

    arr = np.array(vector, dtype=np.float32) if isinstance(vector, list) else vector
    return base64.b64encode(arr.tobytes())
|
||||
|
||||
|
||||
def deserialize_vector(data: Optional[bytes], dialect: Dialect) -> Optional[np.ndarray]:
    """Convert a base64-encoded byte string back into a NumPy float32 array.

    SQLite stores the base64 text form, so it is decoded first; other dialects
    hand back raw bytes. Falsy input yields None.
    """
    if not data:
        return None

    raw = base64.b64decode(data) if dialect.name == "sqlite" else data
    return np.frombuffer(raw, dtype=np.float32)
|
||||
@@ -1,159 +1,80 @@
|
||||
import base64
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from sqlalchemy import JSON
|
||||
from sqlalchemy.types import BINARY, TypeDecorator
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
||||
|
||||
|
||||
class EmbeddingConfigColumn(TypeDecorator):
    """Custom type for storing EmbeddingConfig as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Use the dialect's native JSON type for storage.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Outbound: dump the pydantic model to a plain dict; dicts/None pass through.
        if value and isinstance(value, EmbeddingConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        # Inbound: rebuild EmbeddingConfig from the stored dict; falsy values pass through.
        if value:
            return EmbeddingConfig(**value)
        return value
|
||||
from letta.helpers.converters import (
|
||||
deserialize_embedding_config,
|
||||
deserialize_llm_config,
|
||||
deserialize_tool_calls,
|
||||
deserialize_tool_rules,
|
||||
deserialize_vector,
|
||||
serialize_embedding_config,
|
||||
serialize_llm_config,
|
||||
serialize_tool_calls,
|
||||
serialize_tool_rules,
|
||||
serialize_vector,
|
||||
)
|
||||
|
||||
|
||||
class LLMConfigColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing LLMConfig as JSON.

    Serialization is delegated to the shared helpers in
    letta.helpers.converters so the logic lives in one place.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Use the dialect's native JSON type for storage.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Outbound: LLMConfig -> plain dict (dicts/None pass through).
        return serialize_llm_config(value)

    def process_result_value(self, value, dialect):
        # Inbound: stored dict -> LLMConfig (falsy values -> None).
        return deserialize_llm_config(value)
|
||||
|
||||
|
||||
class EmbeddingConfigColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing EmbeddingConfig as JSON."""

    impl = JSON
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Outbound: delegate to the shared converter helper.
        return serialize_embedding_config(value)

    def process_result_value(self, value, dialect):
        # Inbound: delegate to the shared converter helper.
        return deserialize_embedding_config(value)
|
||||
|
||||
|
||||
class ToolRulesColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing a list of ToolRules as JSON.

    All serialization logic, including the enum-to-string conversion and the
    'children' validation for child-constraint rules, lives in the shared
    helpers in letta.helpers.converters.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Use the dialect's native JSON type for storage.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Convert a list of ToolRules to JSON-serializable format."""
        return serialize_tool_rules(value)

    def process_result_value(self, value, dialect):
        """Convert JSON back to a list of ToolRule objects."""
        return deserialize_tool_rules(value)
|
||||
|
||||
|
||||
class ToolCallColumn(TypeDecorator):
    """Custom SQLAlchemy column type for storing OpenAI ToolCall objects as JSON.

    Serialization is delegated to the shared helpers in
    letta.helpers.converters.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Use the dialect's native JSON type for storage.
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Outbound: ToolCall models/dicts -> list of plain dicts.
        return serialize_tool_calls(value)

    def process_result_value(self, value, dialect):
        # Inbound: stored dicts -> OpenAIToolCall objects.
        return deserialize_tool_calls(value)
|
||||
|
||||
|
||||
class CommonVector(TypeDecorator):
    """Custom SQLAlchemy column type for storing vectors in SQLite.

    Vectors are stored as base64-encoded float32 bytes; encode/decode is
    delegated to the shared helpers in letta.helpers.converters.
    """

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        # Store as raw binary; the helpers handle base64 encoding.
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        # Outbound: list/ndarray -> base64-encoded float32 bytes (None passes through).
        return serialize_vector(value)

    def process_result_value(self, value, dialect):
        # Inbound: stored bytes -> float32 NumPy array; dialect decides base64 decoding.
        return deserialize_vector(value, dialect)
|
||||
|
||||
Reference in New Issue
Block a user