From 3dc1767f461346a2529d3307d4fc68b7dafbfed6 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Mon, 17 Feb 2025 15:07:12 -0800 Subject: [PATCH] feat: Factor out custom_columns serialize/deserialize logic (#1028) --- letta/helpers/converters.py | 152 +++++++++++++++++++++++++++++++++++ letta/orm/custom_columns.py | 155 +++++++++--------------------------- 2 files changed, 190 insertions(+), 117 deletions(-) create mode 100644 letta/helpers/converters.py diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py new file mode 100644 index 00000000..56757ef8 --- /dev/null +++ b/letta/helpers/converters.py @@ -0,0 +1,152 @@ +import base64 +from typing import Any, Dict, List, Optional, Union + +import numpy as np +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall +from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction +from sqlalchemy import Dialect + +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ToolRuleType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule, ToolRule + +# -------------------------- +# LLMConfig Serialization +# -------------------------- + + +def serialize_llm_config(config: Union[Optional[LLMConfig], Dict]) -> Optional[Dict]: + """Convert an LLMConfig object into a JSON-serializable dictionary.""" + if config and isinstance(config, LLMConfig): + return config.model_dump() + return config + + +def deserialize_llm_config(data: Optional[Dict]) -> Optional[LLMConfig]: + """Convert a dictionary back into an LLMConfig object.""" + return LLMConfig(**data) if data else None + + +# -------------------------- +# EmbeddingConfig Serialization +# -------------------------- + + +def serialize_embedding_config(config: Union[Optional[EmbeddingConfig], Dict]) -> Optional[Dict]: + """Convert an EmbeddingConfig object into a JSON-serializable dictionary.""" + if config and isinstance(config, EmbeddingConfig): + return config.model_dump() + return config + + +def deserialize_embedding_config(data: Optional[Dict]) -> Optional[EmbeddingConfig]: + """Convert a dictionary back into an EmbeddingConfig object.""" + return EmbeddingConfig(**data) if data else None + + +# -------------------------- +# ToolRule Serialization +# -------------------------- + + +def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str, Any]]: + """Convert a list of ToolRules into a JSON-serializable format.""" + + if not tool_rules: + return [] + + data = [{**rule.model_dump(), "type": rule.type.value} for rule in tool_rules] # Convert Enum to string for JSON compatibility + + # Validate ToolRule structure + for rule_data in data: + if rule_data["type"] == ToolRuleType.constrain_child_tools.value and "children" not in rule_data: + raise ValueError(f"Invalid ToolRule serialization: 'children' field missing for rule {rule_data}") + + return data + + +def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]: + """Convert a list of dictionaries back into ToolRule objects.""" + if not data: + return [] + + return [deserialize_tool_rule(rule_data) for rule_data in data] + + +def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]: + """Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'.""" + rule_type = ToolRuleType(data.get("type")) + + if rule_type == ToolRuleType.run_first: + return InitToolRule(**data) + elif rule_type == ToolRuleType.exit_loop: + return TerminalToolRule(**data) + elif rule_type == ToolRuleType.constrain_child_tools: + return ChildToolRule(**data) + elif rule_type == ToolRuleType.conditional: + return ConditionalToolRule(**data) + + raise ValueError(f"Unknown ToolRule type: {rule_type}") + + +# -------------------------- +# ToolCall Serialization +# -------------------------- + + +def serialize_tool_calls(tool_calls: Optional[List[Union[OpenAIToolCall, dict]]]) -> List[Dict]: + """Convert a list of OpenAI ToolCall objects into JSON-serializable format.""" + if not tool_calls: + return [] + + serialized_calls = [] + for call in tool_calls: + if isinstance(call, OpenAIToolCall): + serialized_calls.append(call.model_dump()) + elif isinstance(call, dict): + serialized_calls.append(call) # Already a dictionary, leave it as-is + else: + raise TypeError(f"Unexpected tool call type: {type(call)}") + + return serialized_calls + + +def deserialize_tool_calls(data: Optional[List[Dict]]) -> List[OpenAIToolCall]: + """Convert a JSON list back into OpenAIToolCall objects.""" + if not data: + return [] + + calls = [] + for item in data: + func_data = item.pop("function", None) + tool_call_function = OpenAIFunction(**func_data) if func_data else None + calls.append(OpenAIToolCall(function=tool_call_function, **item)) + + return calls + + +# -------------------------- +# Vector Serialization +# -------------------------- + + +def serialize_vector(vector: Optional[Union[List[float], np.ndarray]]) -> Optional[bytes]: + """Convert a NumPy array or list into a base64-encoded byte string.""" + if vector is None: + return None + if isinstance(vector, list): + vector = np.array(vector, dtype=np.float32) + + return base64.b64encode(vector.tobytes()) + + +def deserialize_vector(data: Optional[bytes], dialect: Dialect) -> Optional[np.ndarray]: + """Convert a base64-encoded byte string back into a NumPy array.""" + if not data: + return None + + if dialect.name == "sqlite": + data = base64.b64decode(data) + + return np.frombuffer(data, dtype=np.float32) diff --git a/letta/orm/custom_columns.py b/letta/orm/custom_columns.py index 43de03d2..13810156 100644 --- a/letta/orm/custom_columns.py +++ b/letta/orm/custom_columns.py @@ -1,159 +1,80 @@ -import base64 -from typing import List, Union - -import numpy as np -from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall -from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction from sqlalchemy import JSON from sqlalchemy.types import BINARY, TypeDecorator -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import ToolRuleType -from letta.schemas.llm_config import LLMConfig -from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule - - -class EmbeddingConfigColumn(TypeDecorator): - """Custom type for storing EmbeddingConfig as JSON.""" - - impl = JSON - cache_ok = True - - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - - def process_bind_param(self, value, dialect): - if value and isinstance(value, EmbeddingConfig): - return value.model_dump() - return value - - def process_result_value(self, value, dialect): - if value: - return EmbeddingConfig(**value) - return value +from letta.helpers.converters import ( + deserialize_embedding_config, + deserialize_llm_config, + deserialize_tool_calls, + deserialize_tool_rules, + deserialize_vector, + serialize_embedding_config, + serialize_llm_config, + serialize_tool_calls, + serialize_tool_rules, + serialize_vector, +) class LLMConfigColumn(TypeDecorator): - """Custom type for storing LLMConfig as JSON.""" + """Custom SQLAlchemy column type for storing LLMConfig as JSON.""" impl = JSON cache_ok = True - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - def process_bind_param(self, value, dialect): - if value and isinstance(value, LLMConfig): - return value.model_dump() - return value + return serialize_llm_config(value) def process_result_value(self, value, dialect): - if value: - return LLMConfig(**value) - return value + return deserialize_llm_config(value) + + +class EmbeddingConfigColumn(TypeDecorator): + """Custom SQLAlchemy column type for storing EmbeddingConfig as JSON.""" + + impl = JSON + cache_ok = True + + def process_bind_param(self, value, dialect): + return serialize_embedding_config(value) + + def process_result_value(self, value, dialect): + return deserialize_embedding_config(value) class ToolRulesColumn(TypeDecorator): - """Custom type for storing a list of ToolRules as JSON""" + """Custom SQLAlchemy column type for storing a list of ToolRules as JSON.""" impl = JSON cache_ok = True - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - def process_bind_param(self, value, dialect): - """Convert a list of ToolRules to JSON-serializable format.""" - if value: - data = [rule.model_dump() for rule in value] - for d in data: - d["type"] = d["type"].value + return serialize_tool_rules(value) - for d in data: - assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field" - return data - return value - - def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]: - """Convert JSON back to a list of ToolRules.""" - if value: - return [self.deserialize_tool_rule(rule_data) for rule_data in value] - return value - - @staticmethod - def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]: - """Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'.""" - rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var - if rule_type == ToolRuleType.run_first or rule_type == "InitToolRule": - data["type"] = ToolRuleType.run_first - return InitToolRule(**data) - elif rule_type == ToolRuleType.exit_loop or rule_type == "TerminalToolRule": - data["type"] = ToolRuleType.exit_loop - return TerminalToolRule(**data) - elif rule_type == ToolRuleType.constrain_child_tools or rule_type == "ToolRule": - data["type"] = ToolRuleType.constrain_child_tools - rule = ChildToolRule(**data) - return rule - elif rule_type == ToolRuleType.conditional: - rule = ConditionalToolRule(**data) - return rule - else: - raise ValueError(f"Unknown tool rule type: {rule_type}") + def process_result_value(self, value, dialect): + return deserialize_tool_rules(value) class ToolCallColumn(TypeDecorator): + """Custom SQLAlchemy column type for storing OpenAI ToolCall objects as JSON.""" impl = JSON cache_ok = True - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - def process_bind_param(self, value, dialect): - if value: - values = [] - for v in value: - if isinstance(v, OpenAIToolCall): - values.append(v.model_dump()) - else: - values.append(v) - return values - - return value + return serialize_tool_calls(value) def process_result_value(self, value, dialect): - if value: - tools = [] - for tool_value in value: - if "function" in tool_value: - tool_call_function = OpenAIFunction(**tool_value["function"]) - del tool_value["function"] - else: - tool_call_function = None - tools.append(OpenAIToolCall(function=tool_call_function, **tool_value)) - return tools - return value + return deserialize_tool_calls(value) class CommonVector(TypeDecorator): - """Common type for representing vectors in SQLite""" + """Custom SQLAlchemy column type for storing vectors in SQLite.""" impl = BINARY cache_ok = True - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(BINARY()) - def process_bind_param(self, value, dialect): - if value is None: - return value - if isinstance(value, list): - value = np.array(value, dtype=np.float32) - return base64.b64encode(value.tobytes()) + return serialize_vector(value) def process_result_value(self, value, dialect): - if not value: - return value - if dialect.name == "sqlite": - value = base64.b64decode(value) - return np.frombuffer(value, dtype=np.float32) + return deserialize_vector(value, dialect)