TM-SGNL-iOS/Scripts/sds_codegen/sds_generate.py
TeleMessage developers dde0620daf initial commit
2025-05-03 12:28:28 -07:00

2956 lines
96 KiB
Python
Executable file

#!/usr/bin/env python3
import os
import subprocess
import argparse
import re
import json
import sds_common
from sds_common import fail
import random
# TODO: We should probably generate a class that knows how to set up
# the database. It would:
#
# * Create all tables (or apply database schema).
# * Register renamed classes.
# [NSKeyedUnarchiver setClass:[OWSUserProfile class] forClassName:[OWSUserProfile collection]];
# [NSKeyedUnarchiver setClass:[OWSDatabaseMigration class] forClassName:[OWSDatabaseMigration collection]];

# We consider any subclass of TSYapDatabaseObject (or of the newer BaseModel,
# which ParsedClass treats identically) to be a "serializable model".
#
# We treat direct subclasses of TSYapDatabaseObject as "roots" of the model class hierarchy.
# Only root models do deserialization.
OLD_BASE_MODEL_CLASS_NAME = "TSYapDatabaseObject"
NEW_BASE_MODEL_CLASS_NAME = "BaseModel"
# Delimits a generated region inside an Obj-C file. update_generated_snippet()
# replaces the text between the first and last occurrence of this marker, so it
# must appear at least twice in the target file.
CODE_GEN_SNIPPET_MARKER_OBJC = "// --- CODE GENERATION MARKER"
# GRDB seems to encode non-primitives using JSON.
# GRDB chokes when it decodes this JSON, due to it being a JSON "fragment".
# Either this is a bug in GRDB or we're using GRDB incorrectly.
# Until we resolve this issue, we need to encode/decode
# non-primitives ourselves.
#
# USE_CODABLE_FOR_PRIMITIVES: when False, is_objc_type_codable() reports nothing
# as Codable and struct types (CGSize/CGRect/CGPoint) are not marked Codable.
# USE_CODABLE_FOR_NONPRIMITIVES: additionally gates Codable detection for
# NSArray/NSDictionary element types in is_objc_type_codable().
USE_CODABLE_FOR_PRIMITIVES = False
USE_CODABLE_FOR_NONPRIMITIVES = False
def update_generated_snippet(file_path, marker, snippet):
    """Replace the region between the first and last `marker` in a file.

    Everything strictly between the first and the last occurrence of
    `marker` in `file_path` is replaced by `snippet`. The markers are kept,
    and exactly one blank line separates each marker from its neighbours.
    The result is written back via sds_common.write_text_file_if_changed().

    Calls fail() if the file does not exist or if `marker` does not occur
    at least twice.
    """
    if not os.path.exists(file_path):
        fail("Missing file:", file_path)

    with open(file_path, "rt") as source_file:
        original_text = source_file.read()

    first_marker = original_text.find(marker)
    last_marker = original_text.rfind(marker)
    if min(first_marker, last_marker) < 0 or first_marker >= last_marker:
        fail(f"Could not find markers ('{marker}'): {file_path}")

    head = original_text[:first_marker].strip()
    tail = original_text[last_marker + len(marker) :].lstrip()
    # Join with blank lines: head, marker, snippet, marker, tail.
    updated_text = "\n\n".join((head, marker, snippet, marker, tail))
    sds_common.write_text_file_if_changed(file_path, updated_text)
def update_objc_snippet(file_path, snippet):
    """Write a generated Obj-C snippet between the code-generation markers.

    The snippet is cleaned with sds_common.clean_up_generated_objc() and
    stripped. An empty result leaves the file untouched. Otherwise a
    "do not edit" header naming this script is prepended, and the text is
    written between CODE_GEN_SNIPPET_MARKER_OBJC markers in `file_path`.
    """
    cleaned_snippet = sds_common.clean_up_generated_objc(snippet).strip()
    if not cleaned_snippet:
        return

    generated_by_header = (
        "// This snippet is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`."
        % (sds_common.pretty_module_path(__file__),)
    )
    update_generated_snippet(
        file_path,
        CODE_GEN_SNIPPET_MARKER_OBJC,
        generated_by_header + "\n\n" + cleaned_snippet,
    )
# ----
# Class name -> ParsedClass. ParsedClass uses it to walk superclass chains.
# NOTE(review): presumably populated while loading the parser output
# (not visible in this chunk) — confirm.
global_class_map = {}
# NOTE(review): not referenced in this chunk; presumably maps a class name to
# its subclasses — confirm against the code that fills it.
global_subclass_map = {}
# NOTE(review): presumably the parsed argparse namespace, assigned at startup
# (argparse is imported above) — confirm.
global_args = None
# ----
def to_swift_identifier_name(identifier_name):
    """Return `identifier_name` with its first character lower-cased.

    Turns a name such as "UniqueId" into a Swift identifier ("uniqueId").
    The rest of the name is left untouched. An empty name yields "" instead
    of raising IndexError.
    """
    # Slicing [:1] is safe on an empty string, unlike indexing [0].
    return identifier_name[:1].lower() + identifier_name[1:]
class ParsedClass:
def __init__(self, json_dict):
self.name = json_dict.get("name")
self.super_class_name = json_dict.get("super_class_name")
self.filepath = sds_common.sds_from_relative_path(json_dict.get("filepath"))
self.finalize_method_name = json_dict.get("finalize_method_name")
self.property_map = {}
for property_dict in json_dict.get("properties"):
property = ParsedProperty(property_dict)
property.class_name = self.name
# TODO: We should handle all properties?
if property.should_ignore_property():
continue
self.property_map[property.name] = property
def properties(self):
result = []
for name in sorted(self.property_map.keys()):
result.append(self.property_map[name])
return result
def database_subclass_properties(self):
# More than one subclass of a SDS model may declare properties
# with the same name. This is fine, so long as they have
# the same type.
all_property_map = {}
subclass_property_map = {}
root_property_names = set()
for property in self.properties():
all_property_map[property.name] = property
root_property_names.add(property.name)
for subclass in all_descendents_of_class(self):
if should_ignore_class(subclass):
continue
for property in subclass.properties():
duplicate_property = all_property_map.get(property.name)
if duplicate_property is not None:
if (
property.swift_type_safe()
!= duplicate_property.swift_type_safe()
):
print(
"property:",
property.class_name,
property.name,
property.swift_type_safe(),
property.is_optional,
)
print(
"duplicate_property:",
duplicate_property.class_name,
duplicate_property.name,
duplicate_property.swift_type_safe(),
duplicate_property.is_optional,
)
fail("Duplicate property doesn't match:", property.name)
elif property.is_optional != duplicate_property.is_optional:
if property.name in root_property_names:
print(
"property:",
property.class_name,
property.name,
property.swift_type_safe(),
property.is_optional,
)
print(
"duplicate_property:",
duplicate_property.class_name,
duplicate_property.name,
duplicate_property.swift_type_safe(),
duplicate_property.is_optional,
)
fail("Duplicate property doesn't match:", property.name)
# If one subclass property is optional and the other isn't, we should
# treat both as optional for the purposes of the database schema.
if not property.is_optional:
continue
else:
continue
all_property_map[property.name] = property
subclass_property_map[property.name] = property
result = []
for name in sorted(subclass_property_map.keys()):
result.append(subclass_property_map[name])
return result
def record_id_source(self):
for property in self.properties():
if property.name == "sortId":
return property.name
return None
def is_sds_model(self):
if self.super_class_name is None:
return False
if not self.super_class_name in global_class_map:
return False
if self.super_class_name in (
OLD_BASE_MODEL_CLASS_NAME,
NEW_BASE_MODEL_CLASS_NAME,
):
return True
super_class = global_class_map[self.super_class_name]
return super_class.is_sds_model()
def has_sds_superclass(self):
return (
self.super_class_name
and self.super_class_name in global_class_map
and self.super_class_name != OLD_BASE_MODEL_CLASS_NAME
and self.super_class_name != NEW_BASE_MODEL_CLASS_NAME
)
def table_superclass(self):
if self.super_class_name is None:
return self
if not self.super_class_name in global_class_map:
return self
if self.super_class_name == OLD_BASE_MODEL_CLASS_NAME:
return self
if self.super_class_name == NEW_BASE_MODEL_CLASS_NAME:
return self
super_class = global_class_map[self.super_class_name]
return super_class.table_superclass()
def all_superclass_names(self):
result = [self.name]
if self.super_class_name is not None:
if self.super_class_name in global_class_map:
super_class = global_class_map[self.super_class_name]
result += super_class.all_superclass_names()
return result
def has_any_superclass_with_name(self, name):
return name in self.all_superclass_names()
def should_generate_extensions(self):
if self.name in (
OLD_BASE_MODEL_CLASS_NAME,
NEW_BASE_MODEL_CLASS_NAME,
):
return False
if should_ignore_class(self):
return False
if not self.is_sds_model():
# Only write serialization extensions for SDS models.
return False
# The migration should not be persisted in the data store.
if self.name in (
"OWSDatabaseMigration",
"YDBDatabaseMigration",
"OWSResaveCollectionDBMigration",
):
return False
if self.super_class_name in (
"OWSDatabaseMigration",
"YDBDatabaseMigration",
"OWSResaveCollectionDBMigration",
):
return False
return True
def record_name(self):
return remove_prefix_from_class_name(self.name) + "Record"
def sorted_record_properties(self):
record_name = self.record_name()
# If a property has a custom column source, we don't redundantly create a column for that column
base_properties = [
property
for property in self.properties()
if not property.has_aliased_column_name()
]
# If a property has a custom column source, we don't redundantly create a column for that column
subclass_properties = [
property
for property in self.database_subclass_properties()
if not property.has_aliased_column_name()
]
# We need to maintain a stable ordering of record properties
# across migrations, e.g. adding new columns to the tables.
#
# First, we build a list of "model" properties. This is the
# the superset of properties in the model base class and all
# of its subclasses.
#
# NOTE: We punch two values onto these properties:
# force_optional and property_order.
record_properties = []
for property in base_properties:
property.force_optional = False
record_properties.append(property)
for property in subclass_properties:
# We must "force" subclass properties to be optional
# since they don't apply to the base model and other
# subclasses.
property.force_optional = True
record_properties.append(property)
for property in record_properties:
# Try to load the known "order" for each property.
#
# "Orders" are indices used to ensure a stable ordering.
# We find the "orders" of all properties that already have
# one.
#
# This will initially be nil for new properties
# which have not yet been assigned an order.
property.property_order = property_order_for_property(property, record_name)
all_property_orders = [
property.property_order
for property in record_properties
if property.property_order
]
# We determine the "next" order we would assign to any
# new property without an order.
next_property_order = 1 + (
max(all_property_orders) if len(all_property_orders) > 0 else 0
)
# Pre-sort model properties by name, so that if we add more
# than one at a time they are nicely (and stable-y) sorted
# in an attractive way.
record_properties.sort(key=lambda value: value.name)
# Now iterate over all model properties and assign an order
# to any new properties without one.
for property in record_properties:
if property.property_order is None:
property.property_order = next_property_order
# We "set" the order in the mapping which is persisted
# as JSON to ensure continuity.
set_property_order_for_property(
property, record_name, next_property_order
)
next_property_order = next_property_order + 1
# Now sort the model properties, applying the ordering.
record_properties.sort(key=lambda value: value.property_order)
return record_properties
class TypeInfo:
    """Type and persistence description for one model property.

    Holds the property's Swift and Obj-C type names. Flags record how it is
    stored (blob, Codable, enum), along with optional per-field overrides.
    The class emits the Swift snippets that (de)serialize the property to
    and from its database record.
    """

    # NSNumber accessor used to unbox each Swift numeric type when serializing.
    _NSNUMBER_CONVERSION_METHODS = {
        "Int8": "int8Value",
        "UInt8": "uint8Value",
        "Int16": "int16Value",
        "UInt16": "uint16Value",
        "Int32": "int32Value",
        "UInt32": "uint32Value",
        "Int64": "int64Value",
        "UInt64": "uint64Value",
        "Float": "floatValue",
        "Double": "doubleValue",
        "Bool": "boolValue",
        "Int": "intValue",
        "UInt": "uintValue",
    }

    # Values deserialized by wrapping the raw column value in an enum
    # initializer: value name -> (Swift enum type, expression suffix).
    _RAW_VALUE_ENUM_COLUMNS = {
        "conversationColorName": ("ConversationColorName", ""),
        "mentionNotificationMode": ("TSThreadMentionNotificationMode", " ?? .default"),
        "storyViewMode": ("TSThreadStoryViewMode", " ?? .default"),
    }

    def __init__(
        self,
        swift_type,
        objc_type,
        should_use_blob=False,
        is_codable=False,
        is_enum=False,
        field_override_column_type=None,
        field_override_record_swift_type=None,
    ):
        self._swift_type = swift_type
        self._objc_type = objc_type
        self.should_use_blob = should_use_blob
        self.is_codable = is_codable
        self.is_enum = is_enum
        self.field_override_column_type = field_override_column_type
        self.field_override_record_swift_type = field_override_record_swift_type

    def swift_type(self):
        """Return the Swift type name as a string."""
        return str(self._swift_type)

    def objc_type(self):
        """Return the Obj-C type name as a string."""
        return str(self._objc_type)

    def database_column_type(self, value_name):
        """Return the GRDB column type (e.g. ".int") used to store this type.

        This defines the mapping of Swift types to database column types.
        Sub-models and collections (e.g. [String]) are stored as blobs.
        NSDates are stored as doubles (timeIntervalSince1970).
        `value_name` is unused and kept for interface compatibility. Calls
        fail() for a Swift "Date" type and for unknown types.
        """
        if self.field_override_column_type is not None:
            return self.field_override_column_type
        elif self.should_use_blob or self.is_codable:
            return ".blob"
        elif self.is_enum:
            return ".int"
        elif self._swift_type == "String":
            return ".unicodeString"
        elif self._objc_type == "NSDate *":
            # Persist dates as NSTimeInterval timeIntervalSince1970.
            return ".double"
        elif self._swift_type == "Date":
            fail(
                'We should not use `Date` as a "swift type" since all NSDates are serialized as doubles.',
                self._swift_type,
            )
        elif self._swift_type == "Data":
            return ".blob"
        elif self._swift_type == "Bool":
            return ".int"
        elif self._swift_type in ("Double", "Float"):
            return ".double"
        elif self.is_numeric():
            return ".int64"
        else:
            fail("Unknown type(1):", self._swift_type)

    def is_numeric(self):
        """Return True for Swift scalar types, including Bool and floats."""
        # TODO: We need to revisit how we serialize numeric types.
        return self._swift_type in (
            "Bool",
            "UInt64",
            "UInt",
            "Int64",
            "Int",
            "Int32",
            "UInt32",
            "Double",
            "Float",
        )

    def should_cast_to_swift(self):
        """Return True for numeric types other than Bool, Int64 and UInt64."""
        if self._swift_type in ("Bool", "Int64", "UInt64"):
            return False
        return self.is_numeric()

    def deserialize_record_invocation(
        self, property, value_name, is_optional, did_force_optional
    ):
        """Return the Swift lines that bind `value_name` from `record`.

        `property` supplies the record column name via column_name().
        `is_optional` is the model property's optionality.
        `did_force_optional` is True when the record column was forced to be
        optional (subclass-only columns). Returns a list of Swift source
        lines.
        """
        value_expr = "record.%s" % (property.column_name(),)

        # SDSDeserialization helpers (if any) that unwrap the raw record
        # value for optional / force-optional columns respectively.
        deserialization_optional = None
        deserialization_not_optional = None
        deserialization_conversion = ""
        if self._swift_type == "String":
            deserialization_not_optional = "required"
        elif self._objc_type == "NSDate *":
            pass
        elif self._swift_type == "Date":
            fail("Unknown type(0):", self._swift_type)
        elif self.is_codable:
            deserialization_not_optional = "required"
        elif self._swift_type == "Data":
            deserialization_optional = "optionalData"
            deserialization_not_optional = "required"
        elif self.is_numeric():
            deserialization_optional = "optionalNumericAsNSNumber"
            deserialization_not_optional = "required"
            deserialization_conversion = ", conversion: { NSNumber(value: $0) }"

        initializer_param_type = self.swift_type()
        if is_optional:
            initializer_param_type = initializer_param_type + "?"

        # Special-case the unpacking of the auto-incremented primary key.
        if value_expr == "record.id":
            value_expr = "%s(recordId)" % (initializer_param_type,)
        elif is_optional:
            if deserialization_optional is not None:
                value_expr = 'SDSDeserialization.%s(%s, name: "%s"%s)' % (
                    deserialization_optional,
                    value_expr,
                    value_name,
                    deserialization_conversion,
                )
        elif did_force_optional:
            if deserialization_not_optional is not None:
                value_expr = 'try SDSDeserialization.%s(%s, name: "%s")' % (
                    deserialization_not_optional,
                    value_expr,
                    value_name,
                )
        # Otherwise: a non-optional column needs no unpacking.

        raw_value_enum = self._RAW_VALUE_ENUM_COLUMNS.get(value_name)
        if raw_value_enum is not None:
            enum_type, fallback_suffix = raw_value_enum
            value_statement = "let %s: %s = %s(rawValue: %s)%s" % (
                value_name,
                enum_type,
                enum_type,
                value_expr,
                fallback_suffix,
            )
        elif self.is_codable:
            value_statement = "let %s: %s = %s" % (
                value_name,
                initializer_param_type,
                value_expr,
            )
        elif self.should_use_blob:
            blob_name = "%sSerialized" % (str(value_name),)
            if is_optional or did_force_optional:
                serialized_statement = "let %s: Data? = %s" % (blob_name, value_expr)
            else:
                serialized_statement = "let %s: Data = %s" % (blob_name, value_expr)
            if is_optional:
                value_statement = (
                    'let %s: %s? = try SDSDeserialization.optionalUnarchive(%s, name: "%s")'
                    % (value_name, self._swift_type, blob_name, value_name)
                )
            else:
                value_statement = (
                    'let %s: %s = try SDSDeserialization.unarchive(%s, name: "%s")'
                    % (value_name, self._swift_type, blob_name, value_name)
                )
            return [serialized_statement, value_statement]
        elif self.is_enum and did_force_optional and not is_optional:
            return [
                "guard let %s: %s = %s else {"
                % (value_name, initializer_param_type, value_expr),
                " throw SDSError.missingRequiredField()",
                "}",
            ]
        elif is_optional and self._objc_type == "NSNumber *":
            return ["let %s: %s = %s" % (value_name, "NSNumber?", value_expr)]
        elif self._objc_type == "NSDate *":
            # Persist dates as NSTimeInterval timeIntervalSince1970.
            interval_name = "%sInterval" % (str(value_name),)
            if did_force_optional:
                serialized_statements = [
                    "guard let %s: Double = %s else {" % (interval_name, value_expr),
                    " throw SDSError.missingRequiredField()",
                    "}",
                ]
            elif is_optional:
                serialized_statements = [
                    "let %s: Double? = %s" % (interval_name, value_expr),
                ]
            else:
                serialized_statements = [
                    "let %s: Double = %s" % (interval_name, value_expr),
                ]
            if is_optional:
                value_statement = (
                    'let %s: Date? = SDSDeserialization.optionalDoubleAsDate(%s, name: "%s")'
                    % (value_name, interval_name, value_name)
                )
            else:
                value_statement = (
                    'let %s: Date = SDSDeserialization.requiredDoubleAsDate(%s, name: "%s")'
                    % (value_name, interval_name, value_name)
                )
            return serialized_statements + [value_statement]
        else:
            value_statement = "let %s: %s = %s" % (
                value_name,
                initializer_param_type,
                value_expr,
            )
        return [value_statement]

    def serialize_record_invocation(
        self, property, value_name, is_optional, did_force_optional
    ):
        """Return the Swift expression that converts `value_name` for the record.

        A per-field override from `property` takes precedence. Blobs are
        archived, NSDates archived as doubles, and NSNumbers unboxed.
        Otherwise the value is used as-is. Calls fail() for an NSNumber whose
        Swift type has no known unboxing accessor.
        """
        value_expr = value_name
        override = property.field_override_serialize_record_invocation()
        if override is not None:
            return override % (value_expr,)
        if self.is_codable:
            return value_expr
        if self.should_use_blob:
            if is_optional or did_force_optional:
                return "optionalArchive(%s)" % (value_expr,)
            return "requiredArchive(%s)" % (value_expr,)
        if self._objc_type == "NSDate *":
            if is_optional or did_force_optional:
                return "archiveOptionalDate(%s)" % (value_expr,)
            return "archiveDate(%s)" % (value_expr,)
        if self._objc_type == "NSNumber *":
            # .get() so that unsupported types reach the explicit fail()
            # below instead of raising a bare KeyError.
            conversion_method = self._NSNUMBER_CONVERSION_METHODS.get(self.swift_type())
            if conversion_method is None:
                fail("Could not convert:", self.swift_type())
            serialization_conversion = "{ $0.%s }" % (conversion_method,)
            if is_optional or did_force_optional:
                return "archiveOptionalNSNumber(%s, conversion: %s)" % (
                    value_expr,
                    serialization_conversion,
                )
            return "archiveNSNumber(%s, conversion: %s)" % (
                value_expr,
                serialization_conversion,
            )
        return value_expr

    def record_field_type(self, value_name):
        """Return the Swift type of the record struct field for this value.

        The override type is returned if present. Non-Codable blob values are
        stored as "Data". Otherwise the Swift type itself is used.
        `value_name` is unused and kept for interface compatibility.
        """
        if self.field_override_record_swift_type is not None:
            return self.field_override_record_swift_type
        if self.should_use_blob and not self.is_codable:
            return "Data"
        return self.swift_type()
class ParsedProperty:
    """One property of a parsed model class.

    Built from the parser's JSON (``name``, ``is_optional``, ``objc_type``,
    ``class_name``). The class maps the Obj-C type to a Swift type and
    TypeInfo, and honours per-field overrides from
    ``configuration_json["manually_typed_fields"]``. It also emits the Swift
    snippets used by the generated record code.
    """

    # Obj-C types that map directly onto a Swift type.
    _OBJC_PRIMITIVE_TO_SWIFT = {
        "NSString *": "String",
        # Persist dates as NSTimeInterval timeIntervalSince1970.
        "NSDate *": "Double",
        "NSData *": "Data",
        "BOOL": "Bool",
        "NSInteger": "Int",
        "NSUInteger": "UInt",
        "int32_t": "Int32",
        "uint32_t": "UInt32",
        "int64_t": "Int64",
        "long long": "Int64",
        "unsigned long long": "UInt64",
        "uint64_t": "UInt64",
        "unsigned long": "UInt64",
        "unsigned int": "UInt32",
        "double": "Double",
        "float": "Float",
        "CGFloat": "Double",
    }

    def __init__(self, json_dict):
        self.name = json_dict.get("name")
        self.is_optional = json_dict.get("is_optional")
        self.objc_type = json_dict.get("objc_type")
        self.class_name = json_dict.get("class_name")
        self.swift_type = None

    def try_to_convert_objc_primitive_to_swift(self, objc_type, unpack_nsnumber=True):
        """Return the Swift type for a primitive or Foundation value Obj-C type.

        "NSNumber *" becomes swift_type_for_nsnumber(self) when
        `unpack_nsnumber` is true, and "NSNumber" otherwise. Returns None for
        any other (non-primitive) type. Calls fail() if `objc_type` is None.
        """
        if objc_type is None:
            fail("Missing type")
        elif objc_type == "NSNumber *":
            if unpack_nsnumber:
                return swift_type_for_nsnumber(self)
            return "NSNumber"
        return self._OBJC_PRIMITIVE_TO_SWIFT.get(objc_type)

    @staticmethod
    def _split_type_arguments(type_arguments):
        """Split "K *, V *" at its first top-level comma.

        Commas nested inside <...> (e.g. in a dictionary value type) are
        ignored. Returns a (key, value) tuple of stripped strings, or None
        if there is no top-level comma.
        """
        depth = 0
        for index, char in enumerate(type_arguments):
            if char == "<":
                depth += 1
            elif char == ">":
                depth -= 1
            elif char == "," and depth == 0:
                return (
                    type_arguments[:index].strip(),
                    type_arguments[index + 1 :].strip(),
                )
        return None

    def _convert_type_argument(self, objc_type):
        """Convert a generic type argument (NSNumber is not unpacked).

        Calls fail() if the argument is not a supported class type.
        """
        swift_type = self.convert_objc_class_to_swift(objc_type, unpack_nsnumber=False)
        if swift_type is None:
            fail("Unexpected type:", objc_type)
        return swift_type

    # NOTE: This method recurses to unpack types like: NSArray<NSArray<SomeClassName *> *> *
    def convert_objc_class_to_swift(self, objc_type, unpack_nsnumber=True):
        """Return the Swift type for an Obj-C object type ("id" or "... *").

        NSArray/NSDictionary generics (possibly nested) become Swift
        array/dictionary types. NSOrderedSet loses its element type. Other
        classes map to their bare name. Returns None for non-object types.
        Calls fail() on unsupported syntax.
        """
        if objc_type == "id":
            return "AnyObject"
        elif not objc_type.endswith(" *"):
            return None

        swift_primitive = self.try_to_convert_objc_primitive_to_swift(
            objc_type, unpack_nsnumber=unpack_nsnumber
        )
        if swift_primitive is not None:
            return swift_primitive

        array_match = re.search(r"^NS(Mutable)?Array<(.+)> \*$", objc_type)
        if array_match is not None:
            return "[" + self._convert_type_argument(array_match.group(2)) + "]"

        dict_match = re.search(r"^NS(Mutable)?Dictionary<(.+)> \*$", objc_type)
        if dict_match is not None:
            # Split at the top-level comma so nested generic value types work.
            split = self._split_type_arguments(dict_match.group(2))
            if split is not None:
                key_type, value_type = split
                return (
                    "["
                    + self._convert_type_argument(key_type)
                    + ": "
                    + self._convert_type_argument(value_type)
                    + "]"
                )

        ordered_set_match = re.search(r"^NSOrderedSet<(.+)> \*$", objc_type)
        if ordered_set_match is not None:
            # swift has no primitive for ordered set, so we lose the element type
            return "NSOrderedSet"

        swift_type = objc_type[: -len(" *")]
        if "<" in swift_type or "{" in swift_type or "*" in swift_type:
            fail("Unexpected type:", objc_type)
        return swift_type

    def try_to_convert_objc_type_to_type_info(self):
        """Build a TypeInfo from this property's Obj-C type.

        Precedence: manual override, known enum, primitive, CG struct, then
        any object type (persisted as a blob). Calls fail() if the type is
        missing or unknown.
        """
        objc_type = self.objc_type
        if objc_type is None:
            fail("Missing type")
        elif self.field_override_swift_type():
            return TypeInfo(
                self.field_override_swift_type(),
                objc_type,
                should_use_blob=self.field_override_should_use_blob(),
                is_enum=self.field_override_is_enum(),
                field_override_column_type=self.field_override_column_type(),
                field_override_record_swift_type=self.field_override_record_swift_type(),
            )
        elif objc_type in enum_type_map:
            return TypeInfo(objc_type, objc_type, is_enum=True)
        elif objc_type.startswith("enum "):
            return TypeInfo(objc_type[len("enum ") :], objc_type, is_enum=True)

        swift_primitive = self.try_to_convert_objc_primitive_to_swift(objc_type)
        if swift_primitive is not None:
            return TypeInfo(swift_primitive, objc_type)

        if objc_type in (
            "struct CGSize",
            "struct CGRect",
            "struct CGPoint",
        ):
            struct_type = objc_type[len("struct ") :]
            return TypeInfo(
                struct_type,
                struct_type,
                should_use_blob=True,
                is_codable=USE_CODABLE_FOR_PRIMITIVES,
            )

        swift_type = self.convert_objc_class_to_swift(objc_type)
        if swift_type is not None:
            # Object types are always persisted as archived blobs
            # (is_codable=False), even when is_objc_type_codable() would
            # report them as Codable.
            return TypeInfo(
                swift_type, objc_type, should_use_blob=True, is_codable=False
            )

        fail("Unknown type(3):", self.class_name, self.objc_type, self.name)

    # NOTE: This method recurses to unpack types like: NSArray<NSArray<SomeClassName *> *> *
    def is_objc_type_codable(self, objc_type):
        """Return True if `objc_type` could be encoded with Codable.

        Always False unless USE_CODABLE_FOR_PRIMITIVES is set. Collection
        types additionally require USE_CODABLE_FOR_NONPRIMITIVES.
        """
        if not USE_CODABLE_FOR_PRIMITIVES:
            return False
        if objc_type in ("NSString *",):
            return True
        elif objc_type in (
            "struct CGSize",
            "struct CGRect",
            "struct CGPoint",
        ):
            return True
        elif self.field_override_is_objc_codable() is not None:
            return self.field_override_is_objc_codable()
        elif objc_type in enum_type_map:
            return True
        elif objc_type.startswith("enum "):
            return True

        if not USE_CODABLE_FOR_NONPRIMITIVES:
            return False

        array_match = re.search(r"^NS(Mutable)?Array<(.+)> \*$", objc_type)
        if array_match is not None:
            return self.is_objc_type_codable(array_match.group(2))

        dict_match = re.search(r"^NS(Mutable)?Dictionary<(.+)> \*$", objc_type)
        if dict_match is not None:
            split = self._split_type_arguments(dict_match.group(2))
            if split is not None:
                key_type, value_type = split
                return self.is_objc_type_codable(key_type) and self.is_objc_type_codable(
                    value_type
                )
        return False

    def field_override_swift_type(self):
        return self._field_override("swift_type")

    def field_override_is_objc_codable(self):
        return self._field_override("is_objc_codable")

    def field_override_is_enum(self):
        return self._field_override("is_enum")

    def field_override_column_type(self):
        return self._field_override("column_type")

    def field_override_record_swift_type(self):
        return self._field_override("record_swift_type")

    def field_override_serialize_record_invocation(self):
        return self._field_override("serialize_record_invocation")

    def field_override_should_use_blob(self):
        return self._field_override("should_use_blob")

    def field_override_objc_initializer_type(self):
        return self._field_override("objc_initializer_type")

    def _field_override(self, override_field):
        """Return the manual override `override_field` for this property, or None.

        Overrides live in configuration_json["manually_typed_fields"], keyed
        by "ClassName.propertyName". A missing entry, or an entry that omits
        `override_field`, yields None. Calls fail() if the configuration has
        no "manually_typed_fields" section.
        """
        manually_typed_fields = configuration_json.get("manually_typed_fields")
        if manually_typed_fields is None:
            fail("Configuration JSON is missing manually_typed_fields")
        key = self.class_name + "." + self.name
        overrides = manually_typed_fields.get(key)
        if overrides is None:
            return None
        return overrides.get(override_field)

    def type_info(self):
        """Return the TypeInfo for this property.

        An explicitly assigned `swift_type` is used when set. Otherwise the
        TypeInfo is derived from the Obj-C type.
        """
        if self.swift_type is not None:
            should_use_blob = self.swift_type.startswith(
                ("[", "{")
            ) or is_swift_class_name(self.swift_type)
            return TypeInfo(
                self.swift_type,
                self.objc_type,
                should_use_blob=should_use_blob,
                is_codable=USE_CODABLE_FOR_PRIMITIVES,
                # Pass the override value, not the bound accessor method.
                field_override_column_type=self.field_override_column_type(),
            )
        return self.try_to_convert_objc_type_to_type_info()

    def swift_type_safe(self):
        return self.type_info().swift_type()

    def objc_type_safe(self):
        """Return the Obj-C type used in initializers, without any "enum " prefix."""
        if self.field_override_objc_initializer_type() is not None:
            return self.field_override_objc_initializer_type()
        result = self.type_info().objc_type()
        if result.startswith("enum "):
            result = result[len("enum ") :]
        return result

    def database_column_type(self):
        return self.type_info().database_column_type(self.name)

    def should_ignore_property(self):
        return should_ignore_property(self)

    def has_aliased_column_name(self):
        return aliased_column_name_for_property(self) is not None

    def deserialize_record_invocation(self, value_name, did_force_optional):
        return self.type_info().deserialize_record_invocation(
            self, value_name, self.is_optional, did_force_optional
        )

    def deep_copy_record_invocation(self, value_name, did_force_optional):
        """Return Swift lines that copy this property from `modelToCopy`.

        Simple Foundation/CG values, numerics and enums are copied by plain
        assignment. Everything else goes through `DeepCopies.deepCopy`.
        `did_force_optional` is accepted for signature parity with the other
        *_invocation methods but is unused.
        """
        swift_type = self.swift_type_safe()
        objc_type = self.objc_type_safe()
        is_optional = self.is_optional
        model_accessor = accessor_name_for_property(self)
        optional_suffix = "?" if is_optional else ""

        simple_type_map = {
            "NSString *": "String",
            "NSNumber *": "NSNumber",
            "NSDate *": "Date",
            "NSData *": "Data",
            "CGSize": "CGSize",
            "CGRect": "CGRect",
            "CGPoint": "CGPoint",
        }
        if objc_type in simple_type_map:
            return [
                "let %s: %s = modelToCopy.%s"
                % (
                    value_name,
                    simple_type_map[objc_type] + optional_suffix,
                    model_accessor,
                ),
            ]

        initializer_param_type = swift_type + optional_suffix
        # Numeric and enum values can be shallow-copied by plain assignment.
        if self.type_info().is_numeric() or self.is_enum():
            return [
                "let %s: %s = modelToCopy.%s"
                % (value_name, initializer_param_type, model_accessor),
            ]

        initializer_param_type = initializer_param_type.replace("AnyObject", "Any")
        if is_optional:
            return [
                "let %s: %s" % (value_name, initializer_param_type),
                "if let %sForCopy = modelToCopy.%s {" % (value_name, model_accessor),
                " %s = try DeepCopies.deepCopy(%sForCopy)" % (value_name, value_name),
                "} else {",
                " %s = nil" % (value_name,),
                "}",
            ]
        return [
            "let %s: %s = try DeepCopies.deepCopy(modelToCopy.%s)"
            % (value_name, initializer_param_type, model_accessor),
        ]

    def possible_class_type_for_property(self):
        """Return the ParsedClass matching this property's type, or None."""
        swift_type = self.swift_type_safe()
        if swift_type in global_class_map:
            return global_class_map[swift_type]
        objc_type = self.objc_type_safe()
        if objc_type.endswith(" *"):
            objc_type = objc_type[:-2]
        if objc_type in global_class_map:
            return global_class_map[objc_type]
        return None

    def serialize_record_invocation(self, value_name, did_force_optional):
        return self.type_info().serialize_record_invocation(
            self, value_name, self.is_optional, did_force_optional
        )

    def record_field_type(self):
        return self.type_info().record_field_type(self.name)

    def is_enum(self):
        return self.type_info().is_enum

    def swift_identifier(self):
        return to_swift_identifier_name(self.name)

    def column_name(self):
        """Return the record column name: aliased, else custom, else the Swift identifier."""
        aliased_column_name = aliased_column_name_for_property(self)
        if aliased_column_name is not None:
            return aliased_column_name
        custom_column_name = custom_column_name_for_property(self)
        if custom_column_name is not None:
            return custom_column_name
        return self.swift_identifier()
def ows_getoutput(cmd):
    """Run `cmd` (without a shell) and capture its output.

    Returns a (returncode, stdout, stderr) tuple. stdout and stderr are the
    raw bytes written by the process.
    """
    completed = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    return completed.returncode, completed.stdout, completed.stderr
# ---- Parsing
def properties_and_inherited_properties(clazz):
    """Return clazz's properties, preceded by those of all known ancestors.

    The chain is followed through ``super_class_name`` while the superclass
    is present in ``global_class_map``. The most distant ancestor's
    properties come first and clazz's own properties come last.
    """
    inherited = []
    if clazz.super_class_name in global_class_map:
        parent = global_class_map[clazz.super_class_name]
        inherited = properties_and_inherited_properties(parent)
    return inherited + clazz.properties()
def generate_swift_extensions_for_model(clazz):
if not clazz.should_generate_extensions():
return
has_sds_superclass = clazz.has_sds_superclass()
has_remove_methods = clazz.name not in ("TSThread", "TSInteraction")
has_grdb_serializer = clazz.name in ("TSInteraction")
swift_filename = os.path.basename(clazz.filepath)
swift_filename = swift_filename[: swift_filename.find(".")] + "+SDS.swift"
swift_filepath = os.path.join(os.path.dirname(clazz.filepath), swift_filename)
record_type = get_record_type(clazz)
# TODO: We'll need to import SignalServiceKit for non-SSK models.
swift_body = """//
// Copyright 2022 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
%simport GRDB
// NOTE: This file is generated by %s.
// Do not manually edit it, instead run `sds_codegen.sh`.
""" % (
"" if has_sds_superclass else "public ",
sds_common.pretty_module_path(__file__),
)
if not has_sds_superclass:
# If a property has a custom column source, we don't redundantly create a column for that column
base_properties = [
property
for property in clazz.properties()
if not property.has_aliased_column_name()
]
# If a property has a custom column source, we don't redundantly create a column for that column
subclass_properties = [
property
for property in clazz.database_subclass_properties()
if not property.has_aliased_column_name()
]
swift_body += """
// MARK: - Record
"""
record_name = clazz.record_name()
swift_body += """
public struct %s: SDSRecord {
public weak var delegate: SDSRecordDelegate?
public var tableMetadata: SDSTableMetadata {
%sSerializer.table
}
public static var databaseTableName: String {
%sSerializer.table.tableName
}
public var id: Int64?
// This defines all of the columns used in the table
// where this model (and any subclasses) are persisted.
public let recordType: SDSRecordType
public let uniqueId: String
""" % (
record_name,
str(clazz.name),
str(clazz.name),
)
def write_record_property(property, force_optional=False):
column_name = property.swift_identifier()
record_field_type = property.record_field_type()
is_optional = property.is_optional or force_optional
optional_split = "?" if is_optional else ""
custom_column_name = custom_column_name_for_property(property)
if custom_column_name is not None:
column_name = custom_column_name
return """ public let %s: %s%s
""" % (
str(column_name),
record_field_type,
optional_split,
)
record_properties = clazz.sorted_record_properties()
# Declare the model properties in the record.
if len(record_properties) > 0:
swift_body += "\n // Properties \n"
for property in record_properties:
swift_body += write_record_property(
property, force_optional=property.force_optional
)
sds_properties = [
ParsedProperty(
{
"name": "id",
"is_optional": False,
"objc_type": "NSInteger",
"class_name": clazz.name,
}
),
ParsedProperty(
{
"name": "recordType",
"is_optional": False,
"objc_type": "NSUInteger",
"class_name": clazz.name,
}
),
ParsedProperty(
{
"name": "uniqueId",
"is_optional": False,
"objc_type": "NSString *",
"class_name": clazz.name,
}
),
]
# We use the pre-sorted collection record_properties so that
# we use the correct property order when generating:
#
# * CodingKeys
# * init(row: Row)
# * The table/column metadata.
persisted_properties = sds_properties + record_properties
swift_body += """
public enum CodingKeys: String, CodingKey, ColumnExpression, CaseIterable {
"""
for property in persisted_properties:
custom_column_name = custom_column_name_for_property(property)
was_property_renamed = was_property_renamed_for_property(property)
if custom_column_name is not None:
if was_property_renamed:
swift_body += """ case %s
""" % (
custom_column_name,
)
else:
swift_body += """ case %s = "%s"
""" % (
custom_column_name,
property.swift_identifier(),
)
else:
swift_body += """ case %s
""" % (
property.swift_identifier(),
)
swift_body += """ }
"""
swift_body += """
public static func columnName(_ column: %s.CodingKeys, fullyQualified: Bool = false) -> String {
fullyQualified ? "\\(databaseTableName).\\(column.rawValue)" : column.rawValue
}
public func didInsert(with rowID: Int64, for column: String?) {
guard let delegate = delegate else {
owsFailDebug("Missing delegate.")
return
}
delegate.updateRowId(rowID)
}
}
""" % (
record_name,
)
swift_body += """
// MARK: - Row Initializer
public extension %s {
static var databaseSelection: [SQLSelectable] {
CodingKeys.allCases
}
init(row: Row) {""" % (
record_name
)
for index, property in enumerate(persisted_properties):
swift_body += """
%s = row[%s]""" % (
property.column_name(),
index,
)
swift_body += """
}
}
"""
swift_body += """
// MARK: - StringInterpolation
public extension String.StringInterpolation {
mutating func appendInterpolation(%(record_identifier)sColumn column: %(record_name)s.CodingKeys) {
appendLiteral(%(record_name)s.columnName(column))
}
mutating func appendInterpolation(%(record_identifier)sColumnFullyQualified column: %(record_name)s.CodingKeys) {
appendLiteral(%(record_name)s.columnName(column, fullyQualified: true))
}
}
""" % {
"record_identifier": record_identifier(clazz.name),
"record_name": record_name,
}
# TODO: Rework metadata to not include, for example, columns, column indices.
swift_body += """
// MARK: - Deserialization
extension %s {
// This method defines how to deserialize a model, given a
// database row. The recordType column is used to determine
// the corresponding model class.
class func fromRecord(_ record: %s) throws -> %s {
""" % (
str(clazz.name),
record_name,
str(clazz.name),
)
swift_body += """
guard let recordId = record.id else {
throw SDSError.invalidValue()
}
switch record.recordType {
"""
deserialize_classes = all_descendents_of_class(clazz) + [clazz]
deserialize_classes.sort(key=lambda value: value.name)
for deserialize_class in deserialize_classes:
if should_ignore_class(deserialize_class):
continue
initializer_params = []
objc_initializer_params = []
objc_super_initializer_args = []
objc_initializer_assigns = []
deserialize_record_type = get_record_type_enum_name(deserialize_class.name)
swift_body += """ case .%s:
""" % (
str(deserialize_record_type),
)
swift_body += """
let uniqueId: String = record.uniqueId
"""
base_property_names = set()
for property in base_properties:
base_property_names.add(property.name)
deserialize_properties = properties_and_inherited_properties(
deserialize_class
)
has_local_properties = False
for property in deserialize_properties:
value_name = "%s" % property.name
if property.name not in ("uniqueId",):
did_force_optional = (
property.name not in base_property_names
) and (not property.is_optional)
for statement in property.deserialize_record_invocation(
value_name, did_force_optional
):
swift_body += " %s\n" % (str(statement),)
initializer_params.append(
"%s: %s"
% (
str(property.name),
value_name,
)
)
objc_initializer_type = str(property.objc_type_safe())
if objc_initializer_type.startswith("NSMutable"):
objc_initializer_type = (
"NS" + objc_initializer_type[len("NSMutable") :]
)
if property.is_optional:
objc_initializer_type = "nullable " + objc_initializer_type
objc_initializer_params.append(
"%s:(%s)%s"
% (
str(property.name),
objc_initializer_type,
str(property.name),
)
)
is_superclass_property = property.class_name != deserialize_class.name
if is_superclass_property:
objc_super_initializer_args.append(
"%s:%s"
% (
str(property.name),
str(property.name),
)
)
else:
has_local_properties = True
if str(property.objc_type_safe()).startswith("NSMutableArray"):
objc_initializer_assigns.append(
"_%s = %s ? [%s mutableCopy] : [NSMutableArray new];"
% (
str(property.name),
str(property.name),
str(property.name),
)
)
elif str(property.objc_type_safe()).startswith(
"NSMutableDictionary"
):
objc_initializer_assigns.append(
"_%s = %s ? [%s mutableCopy] : [NSMutableDictionary new];"
% (
str(property.name),
str(property.name),
str(property.name),
)
)
elif (
deserialize_class.name == "TSIncomingMessage"
and property.name in ("authorUUID", "authorPhoneNumber")
):
pass
else:
objc_initializer_assigns.append(
"_%s = %s;"
% (
str(property.name),
str(property.name),
)
)
# --- Initializer Snippets
h_snippet = ""
h_snippet += """
// clang-format off
- (instancetype)initWithGrdbId:(int64_t)grdbId
uniqueId:(NSString *)uniqueId
"""
for objc_initializer_param in objc_initializer_params[1:]:
alignment = max(
0,
len("- (instancetype)initWithUniqueId")
- objc_initializer_param.index(":"),
)
h_snippet += (" " * alignment) + objc_initializer_param + "\n"
h_snippet += (
"NS_DESIGNATED_INITIALIZER NS_SWIFT_NAME(init(grdbId:%s:));\n"
% ":".join([str(property.name) for property in deserialize_properties])
)
h_snippet += """
// clang-format on
"""
m_snippet = ""
m_snippet += """
// clang-format off
- (instancetype)initWithGrdbId:(int64_t)grdbId
uniqueId:(NSString *)uniqueId
"""
for objc_initializer_param in objc_initializer_params[1:]:
alignment = max(
0,
len("- (instancetype)initWithUniqueId")
- objc_initializer_param.index(":"),
)
m_snippet += (" " * alignment) + objc_initializer_param + "\n"
if len(objc_super_initializer_args) == 1:
suffix = "];"
else:
suffix = ""
m_snippet += """{
self = [super initWithGrdbId:grdbId
uniqueId:uniqueId%s
""" % (
suffix
)
for index, objc_super_initializer_arg in enumerate(
objc_super_initializer_args[1:]
):
alignment = max(
0,
len(" self = [super initWithUniqueId")
- objc_super_initializer_arg.index(":"),
)
if index == len(objc_super_initializer_args) - 2:
suffix = "];"
else:
suffix = ""
m_snippet += (
(" " * alignment) + objc_super_initializer_arg + suffix + "\n"
)
m_snippet += """
if (!self) {
return self;
}
"""
if deserialize_class.name == "TSIncomingMessage":
m_snippet += """
if (authorUUID != nil) {
_authorUUID = authorUUID;
} else if (authorPhoneNumber != nil) {
_authorPhoneNumber = authorPhoneNumber;
}
"""
for objc_initializer_assign in objc_initializer_assigns:
m_snippet += (" " * 4) + objc_initializer_assign + "\n"
if deserialize_class.finalize_method_name is not None:
m_snippet += """
[self %s];
""" % (
str(deserialize_class.finalize_method_name),
)
m_snippet += """
return self;
}
// clang-format on
"""
# Skip initializer generation for classes without any properties.
if not has_local_properties:
h_snippet = ""
m_snippet = ""
if deserialize_class.filepath.endswith(".m"):
m_filepath = deserialize_class.filepath
h_filepath = m_filepath[:-2] + ".h"
update_objc_snippet(h_filepath, h_snippet)
update_objc_snippet(m_filepath, m_snippet)
swift_body += """
"""
# --- Invoke Initializer
initializer_invocation = " return %s(" % str(
deserialize_class.name
)
swift_body += initializer_invocation
initializer_params = [
"grdbId: recordId",
] + initializer_params
swift_body += (",\n" + " " * len(initializer_invocation)).join(
initializer_params
)
swift_body += ")"
swift_body += """
"""
# TODO: We could generate a comment with the Obj-C (or Swift) model initializer
# that this deserialization code expects.
swift_body += """ default:
owsFailDebug("Unexpected record type: \\(record.recordType)")
throw SDSError.invalidValue()
"""
swift_body += """ }
"""
swift_body += """ }
"""
swift_body += """}
"""
# TODO: Remove the serialization glue below.
if not has_sds_superclass:
swift_body += """
// MARK: - SDSModel
extension %s: SDSModel {
public var serializer: SDSSerializer {
// Any subclass can be cast to it's superclass,
// so the order of this switch statement matters.
// We need to do a "depth first" search by type.
switch self {""" % str(
clazz.name
)
for subclass in reversed(all_descendents_of_class(clazz)):
if should_ignore_class(subclass):
continue
swift_body += """
case let model as %s:
assert(type(of: model) == %s.self)
return %sSerializer(model: model)""" % (
str(subclass.name),
str(subclass.name),
str(subclass.name),
)
swift_body += """
default:
return %sSerializer(model: self)
}
}
public func asRecord() -> SDSRecord {
serializer.asRecord()
}
public var sdsTableName: String {
%s.databaseTableName
}
public static var table: SDSTableMetadata {
%sSerializer.table
}
}
""" % (
str(clazz.name),
record_name,
str(clazz.name),
)
if not has_sds_superclass:
swift_body += """
// MARK: - DeepCopyable
extension %(class_name)s: DeepCopyable {
public func deepCopy() throws -> AnyObject {
guard let id = self.grdbId?.int64Value else {
throw OWSAssertionError("Model missing grdbId.")
}
// Any subclass can be cast to its superclass, so the order of these if
// statements matters. We need to do a "depth first" search by type.
""" % {
"class_name": str(clazz.name)
}
classes_to_copy = list(reversed(all_descendents_of_class(clazz))) + [
clazz,
]
for class_to_copy in classes_to_copy:
if should_ignore_class(class_to_copy):
continue
if class_to_copy == clazz:
swift_body += """
do {
let modelToCopy = self
assert(type(of: modelToCopy) == %(class_name)s.self)
""" % {
"class_name": str(class_to_copy.name)
}
else:
swift_body += """
if let modelToCopy = self as? %(class_name)s {
assert(type(of: modelToCopy) == %(class_name)s.self)
""" % {
"class_name": str(class_to_copy.name)
}
initializer_params = []
base_property_names = set()
for property in base_properties:
base_property_names.add(property.name)
deserialize_properties = properties_and_inherited_properties(class_to_copy)
for property in deserialize_properties:
value_name = "%s" % property.name
did_force_optional = (property.name not in base_property_names) and (
not property.is_optional
)
for statement in property.deep_copy_record_invocation(
value_name, did_force_optional
):
swift_body += " %s\n" % (str(statement),)
initializer_params.append(
"%s: %s"
% (
str(property.name),
value_name,
)
)
swift_body += """
"""
# --- Invoke Initializer
initializer_invocation = " return %s(" % str(class_to_copy.name)
swift_body += initializer_invocation
initializer_params = [
"grdbId: id",
] + initializer_params
swift_body += (",\n" + " " * len(initializer_invocation)).join(
initializer_params
)
swift_body += ")"
swift_body += """
}
"""
swift_body += """
}
}
"""
if has_grdb_serializer:
swift_body += """
// MARK: - Table Metadata
extension %sRecord {
// This defines all of the columns used in the table
// where this model (and any subclasses) are persisted.
internal func asArguments() -> StatementArguments {
let databaseValues: [DatabaseValueConvertible?] = [
""" % str(
remove_prefix_from_class_name(clazz.name)
)
def write_grdb_column_metadata(property):
# column_name = property.swift_identifier()
column_name = property.column_name()
return """ %s,
""" % (
str(column_name)
)
for property in sds_properties:
if property.name != "id":
swift_body += write_grdb_column_metadata(property)
if len(record_properties) > 0:
for property in record_properties:
swift_body += write_grdb_column_metadata(property)
swift_body += """
]
return StatementArguments(databaseValues)
}
}
"""
if not has_sds_superclass:
swift_body += """
// MARK: - Table Metadata
extension %sSerializer {
// This defines all of the columns used in the table
// where this model (and any subclasses) are persisted.
""" % str(
clazz.name
)
# Eventually we need a (persistent?) mechanism for guaranteeing
# consistency of column ordering, that is robust to schema
# changes, class hierarchy changes, etc.
column_property_names = []
def write_column_metadata(property, force_optional=False):
column_name = property.swift_identifier()
column_property_names.append(column_name)
is_optional = property.is_optional or force_optional
optional_split = ", isOptional: true" if is_optional else ""
is_unique = column_name == str("uniqueId")
is_unique_split = ", isUnique: true" if is_unique else ""
database_column_type = property.database_column_type()
if property.name == "id":
database_column_type = ".primaryKey"
# TODO: Use skipSelect.
return """ static var %sColumn: SDSColumnMetadata { SDSColumnMetadata(columnName: "%s", columnType: %s%s%s) }
""" % (
str(column_name),
str(column_name),
database_column_type,
optional_split,
is_unique_split,
)
for property in sds_properties:
swift_body += write_column_metadata(property)
if len(record_properties) > 0:
swift_body += " // Properties \n"
for property in record_properties:
swift_body += write_column_metadata(
property, force_optional=property.force_optional
)
database_table_name = "model_%s" % str(clazz.name)
swift_body += """
public static var table: SDSTableMetadata {
SDSTableMetadata(
tableName: "%s",
columns: [
""" % (
database_table_name,
)
swift_body += "\n".join(
[
" %sColumn," % str(column_property_name)
for column_property_name in column_property_names
]
)
swift_body += """
]
)
}
}
"""
# ---- Fetch ----
swift_body += """
// MARK: - Save/Remove/Update
@objc
public extension %(class_name)s {
func anyInsert(transaction: SDSAnyWriteTransaction) {
sdsSave(saveMode: .insert, transaction: transaction)
}
// Avoid this method whenever feasible.
//
// If the record has previously been saved, this method does an overwriting
// update of the corresponding row, otherwise if it's a new record, this
// method inserts a new row.
//
// For performance, when possible, you should explicitly specify whether
// you are inserting or updating rather than calling this method.
func anyUpsert(transaction: SDSAnyWriteTransaction) {
let isInserting: Bool
if %(class_name)s.anyFetch(uniqueId: uniqueId, transaction: transaction) != nil {
isInserting = false
} else {
isInserting = true
}
sdsSave(saveMode: isInserting ? .insert : .update, transaction: transaction)
}
// This method is used by "updateWith..." methods.
//
// This model may be updated from many threads. We don't want to save
// our local copy (this instance) since it may be out of date. We also
// want to avoid re-saving a model that has been deleted. Therefore, we
// use "updateWith..." methods to:
//
// a) Update a property of this instance.
// b) If a copy of this model exists in the database, load an up-to-date copy,
// and update and save that copy.
// b) If a copy of this model _DOES NOT_ exist in the database, do _NOT_ save
// this local instance.
//
// After "updateWith...":
//
// a) Any copy of this model in the database will have been updated.
// b) The local property on this instance will always have been updated.
// c) Other properties on this instance may be out of date.
//
// All mutable properties of this class have been made read-only to
// prevent accidentally modifying them directly.
//
// This isn't a perfect arrangement, but in practice this will prevent
// data loss and will resolve all known issues.
func anyUpdate(transaction: SDSAnyWriteTransaction, block: (%(class_name)s) -> Void) {
block(self)
guard let dbCopy = type(of: self).anyFetch(uniqueId: uniqueId,
transaction: transaction) else {
return
}
// Don't apply the block twice to the same instance.
// It's at least unnecessary and actually wrong for some blocks.
// e.g. `block: { $0 in $0.someField++ }`
if dbCopy !== self {
block(dbCopy)
}
dbCopy.sdsSave(saveMode: .update, transaction: transaction)
}
// This method is an alternative to `anyUpdate(transaction:block:)` methods.
//
// We should generally use `anyUpdate` to ensure we're not unintentionally
// clobbering other columns in the database when another concurrent update
// has occurred.
//
// There are cases when this doesn't make sense, e.g. when we know we've
// just loaded the model in the same transaction. In those cases it is
// safe and faster to do a "overwriting" update
func anyOverwritingUpdate(transaction: SDSAnyWriteTransaction) {
sdsSave(saveMode: .update, transaction: transaction)
}
""" % {
"class_name": str(clazz.name)
}
if has_remove_methods:
swift_body += """
func anyRemove(transaction: SDSAnyWriteTransaction) {
sdsRemove(transaction: transaction)
}
"""
swift_body += """}
"""
# ---- Cursor ----
swift_body += """
// MARK: - %sCursor
@objc
public class %sCursor: NSObject, SDSCursor {
private let transaction: GRDBReadTransaction
private let cursor: RecordCursor<%s>?
init(transaction: GRDBReadTransaction, cursor: RecordCursor<%s>?) {
self.transaction = transaction
self.cursor = cursor
}
public func next() throws -> %s? {
guard let cursor = cursor else {
return nil
}
guard let record = try cursor.next() else {
return nil
}""" % (
str(clazz.name),
str(clazz.name),
record_name,
record_name,
str(clazz.name),
)
cache_code = cache_set_code_for_class(clazz)
if cache_code is not None:
swift_body += """
let value = try %s.fromRecord(record)
%s(value, transaction: transaction.asAnyRead)
return value""" % (
str(clazz.name),
cache_code,
)
else:
swift_body += """
return try %s.fromRecord(record)""" % (
str(clazz.name),
)
swift_body += """
}
public func all() throws -> [%s] {
var result = [%s]()
while true {
guard let model = try next() else {
break
}
result.append(model)
}
return result
}
}
""" % (
str(clazz.name),
str(clazz.name),
)
# ---- Fetch ----
swift_body += """
// MARK: - Obj-C Fetch
@objc
public extension %(class_name)s {
class func grdbFetchCursor(transaction: GRDBReadTransaction) -> %(class_name)sCursor {
let database = transaction.database
do {
let cursor = try %(record_name)s.fetchCursor(database)
return %(class_name)sCursor(transaction: transaction, cursor: cursor)
} catch {
DatabaseCorruptionState.flagDatabaseReadCorruptionIfNecessary(
userDefaults: CurrentAppContext().appUserDefaults(),
error: error
)
owsFailDebug("Read failed: \\(error)")
return %(class_name)sCursor(transaction: transaction, cursor: nil)
}
}
""" % {
"class_name": str(clazz.name),
"record_name": record_name,
}
swift_body += """
// Fetches a single model by "unique id".
class func anyFetch(uniqueId: String,
transaction: SDSAnyReadTransaction) -> %(class_name)s? {
assert(!uniqueId.isEmpty)
""" % {
"class_name": str(clazz.name),
"record_name": record_name,
"record_identifier": record_identifier(clazz.name),
}
cache_code = cache_get_code_for_class(clazz)
if cache_code is not None:
swift_body += """
return anyFetch(uniqueId: uniqueId, transaction: transaction, ignoreCache: false)
}
// Fetches a single model by "unique id".
class func anyFetch(uniqueId: String,
transaction: SDSAnyReadTransaction,
ignoreCache: Bool) -> %(class_name)s? {
assert(!uniqueId.isEmpty)
if !ignoreCache,
let cachedCopy = %(cache_code)s {
return cachedCopy
}
""" % {
"class_name": str(clazz.name),
"cache_code": str(cache_code),
}
swift_body += """
switch transaction.readTransaction {
case .grdbRead(let grdbTransaction):
let sql = "SELECT * FROM \\(%(record_name)s.databaseTableName) WHERE \\(%(record_identifier)sColumn: .uniqueId) = ?"
return grdbFetchOne(sql: sql, arguments: [uniqueId], transaction: grdbTransaction)
}
}
""" % {
"record_name": record_name,
"record_identifier": record_identifier(clazz.name),
}
swift_body += """
// Traverses all records.
// Records are not visited in any particular order.
class func anyEnumerate(
transaction: SDSAnyReadTransaction,
block: (%s, UnsafeMutablePointer<ObjCBool>) -> Void
) {
anyEnumerate(transaction: transaction, batched: false, block: block)
}
// Traverses all records.
// Records are not visited in any particular order.
class func anyEnumerate(
transaction: SDSAnyReadTransaction,
batched: Bool = false,
block: (%s, UnsafeMutablePointer<ObjCBool>) -> Void
) {
let batchSize = batched ? Batching.kDefaultBatchSize : 0
anyEnumerate(transaction: transaction, batchSize: batchSize, block: block)
}
// Traverses all records.
// Records are not visited in any particular order.
//
// If batchSize > 0, the enumeration is performed in autoreleased batches.
class func anyEnumerate(
transaction: SDSAnyReadTransaction,
batchSize: UInt,
block: (%s, UnsafeMutablePointer<ObjCBool>) -> Void
) {
switch transaction.readTransaction {
case .grdbRead(let grdbTransaction):
let cursor = %s.grdbFetchCursor(transaction: grdbTransaction)
Batching.loop(batchSize: batchSize,
loopBlock: { stop in
do {
guard let value = try cursor.next() else {
stop.pointee = true
return
}
block(value, stop)
} catch let error {
owsFailDebug("Couldn't fetch model: \\(error)")
}
})
}
}
""" % (
(str(clazz.name),) * 4
)
swift_body += '''
// Traverses all records' unique ids.
// Records are not visited in any particular order.
class func anyEnumerateUniqueIds(
transaction: SDSAnyReadTransaction,
block: (String, UnsafeMutablePointer<ObjCBool>) -> Void
) {
anyEnumerateUniqueIds(transaction: transaction, batched: false, block: block)
}
// Traverses all records' unique ids.
// Records are not visited in any particular order.
class func anyEnumerateUniqueIds(
transaction: SDSAnyReadTransaction,
batched: Bool = false,
block: (String, UnsafeMutablePointer<ObjCBool>) -> Void
) {
let batchSize = batched ? Batching.kDefaultBatchSize : 0
anyEnumerateUniqueIds(transaction: transaction, batchSize: batchSize, block: block)
}
// Traverses all records' unique ids.
// Records are not visited in any particular order.
//
// If batchSize > 0, the enumeration is performed in autoreleased batches.
class func anyEnumerateUniqueIds(
transaction: SDSAnyReadTransaction,
batchSize: UInt,
block: (String, UnsafeMutablePointer<ObjCBool>) -> Void
) {
switch transaction.readTransaction {
case .grdbRead(let grdbTransaction):
grdbEnumerateUniqueIds(transaction: grdbTransaction,
sql: """
SELECT \\(%sColumn: .uniqueId)
FROM \\(%s.databaseTableName)
""",
batchSize: batchSize,
block: block)
}
}
''' % (
record_identifier(clazz.name),
record_name,
)
swift_body += """
// Does not order the results.
class func anyFetchAll(transaction: SDSAnyReadTransaction) -> [%s] {
var result = [%s]()
anyEnumerate(transaction: transaction) { (model, _) in
result.append(model)
}
return result
}
// Does not order the results.
class func anyAllUniqueIds(transaction: SDSAnyReadTransaction) -> [String] {
var result = [String]()
anyEnumerateUniqueIds(transaction: transaction) { (uniqueId, _) in
result.append(uniqueId)
}
return result
}
""" % (
(str(clazz.name),) * 2
)
# ---- Count ----
swift_body += """
class func anyCount(transaction: SDSAnyReadTransaction) -> UInt {
switch transaction.readTransaction {
case .grdbRead(let grdbTransaction):
return %s.ows_fetchCount(grdbTransaction.database)
}
}
""" % (
record_name,
)
# ---- Remove All ----
if has_remove_methods:
swift_body += """
class func anyRemoveAllWithInstantiation(transaction: SDSAnyWriteTransaction) {
// To avoid mutationDuringEnumerationException, we need to remove the
// instances outside the enumeration.
let uniqueIds = anyAllUniqueIds(transaction: transaction)
for uniqueId in uniqueIds {
autoreleasepool {
guard let instance = anyFetch(uniqueId: uniqueId, transaction: transaction) else {
owsFailDebug("Missing instance.")
return
}
instance.anyRemove(transaction: transaction)
}
}
}
"""
# ---- Exists ----
swift_body += """
class func anyExists(
uniqueId: String,
transaction: SDSAnyReadTransaction
) -> Bool {
assert(!uniqueId.isEmpty)
switch transaction.readTransaction {
case .grdbRead(let grdbTransaction):
let sql = "SELECT EXISTS ( SELECT 1 FROM \\(%s.databaseTableName) WHERE \\(%sColumn: .uniqueId) = ? )"
let arguments: StatementArguments = [uniqueId]
do {
return try Bool.fetchOne(grdbTransaction.database, sql: sql, arguments: arguments) ?? false
} catch {
DatabaseCorruptionState.flagDatabaseReadCorruptionIfNecessary(
userDefaults: CurrentAppContext().appUserDefaults(),
error: error
)
owsFail("Missing instance.")
}
}
}
}
""" % (
record_name,
record_identifier(clazz.name),
)
# ---- Fetch ----
swift_body += """
// MARK: - Swift Fetch
public extension %(class_name)s {
class func grdbFetchCursor(sql: String,
arguments: StatementArguments = StatementArguments(),
transaction: GRDBReadTransaction) -> %(class_name)sCursor {
do {
let sqlRequest = SQLRequest<Void>(sql: sql, arguments: arguments, cached: true)
let cursor = try %(record_name)s.fetchCursor(transaction.database, sqlRequest)
return %(class_name)sCursor(transaction: transaction, cursor: cursor)
} catch {
DatabaseCorruptionState.flagDatabaseReadCorruptionIfNecessary(
userDefaults: CurrentAppContext().appUserDefaults(),
error: error
)
owsFailDebug("Read failed: \\(error)")
return %(class_name)sCursor(transaction: transaction, cursor: nil)
}
}
""" % {
"class_name": str(clazz.name),
"record_name": record_name,
}
string_interpolation_name = remove_prefix_from_class_name(clazz.name)
swift_body += """
class func grdbFetchOne(sql: String,
arguments: StatementArguments = StatementArguments(),
transaction: GRDBReadTransaction) -> %s? {
assert(!sql.isEmpty)
do {
let sqlRequest = SQLRequest<Void>(sql: sql, arguments: arguments, cached: true)
guard let record = try %s.fetchOne(transaction.database, sqlRequest) else {
return nil
}
""" % (
str(clazz.name),
record_name,
)
cache_code = cache_set_code_for_class(clazz)
if cache_code is not None:
swift_body += """
let value = try %s.fromRecord(record)
%s(value, transaction: transaction.asAnyRead)
return value""" % (
str(clazz.name),
cache_code,
)
else:
swift_body += """
return try %s.fromRecord(record)""" % (
str(clazz.name),
)
swift_body += """
} catch {
owsFailDebug("error: \\(error)")
return nil
}
}
}
"""
# ---- Typed Convenience Methods ----
if has_sds_superclass:
swift_body += """
// MARK: - Typed Convenience Methods
@objc
public extension %s {
// NOTE: This method will fail if the object has unexpected type.
class func anyFetch%s(
uniqueId: String,
transaction: SDSAnyReadTransaction
) -> %s? {
assert(!uniqueId.isEmpty)
guard let object = anyFetch(uniqueId: uniqueId,
transaction: transaction) else {
return nil
}
guard let instance = object as? %s else {
owsFailDebug("Object has unexpected type: \\(type(of: object))")
return nil
}
return instance
}
// NOTE: This method will fail if the object has unexpected type.
func anyUpdate%s(transaction: SDSAnyWriteTransaction, block: (%s) -> Void) {
anyUpdate(transaction: transaction) { (object) in
guard let instance = object as? %s else {
owsFailDebug("Object has unexpected type: \\(type(of: object))")
return
}
block(instance)
}
}
}
""" % (
str(clazz.name),
str(remove_prefix_from_class_name(clazz.name)),
str(clazz.name),
str(clazz.name),
str(remove_prefix_from_class_name(clazz.name)),
str(clazz.name),
str(clazz.name),
)
# ---- SDSModel ----
table_superclass = clazz.table_superclass()
table_class_name = str(table_superclass.name)
has_serializable_superclass = table_superclass.name != clazz.name
override_keyword = ""
swift_body += """
// MARK: - SDSSerializer
// The SDSSerializer protocol specifies how to insert and update the
// row that corresponds to this model.
class %sSerializer: SDSSerializer {
private let model: %s
public init(model: %s) {
self.model = model
}
""" % (
str(clazz.name),
str(clazz.name),
str(clazz.name),
)
# --- To Record
root_class = clazz.table_superclass()
root_record_name = remove_prefix_from_class_name(root_class.name) + "Record"
record_id_source = "model.grdbId?.int64Value"
if root_class.record_id_source() is not None:
record_id_source = (
"model.%(source)s > 0 ? Int64(model.%(source)s) : %(default_source)s"
% {
"source": root_class.record_id_source(),
"default_source": record_id_source,
}
)
swift_body += """
// MARK: - Record
func asRecord() -> SDSRecord {
let id: Int64? = %(record_id_source)s
let recordType: SDSRecordType = .%(record_type)s
let uniqueId: String = model.uniqueId
""" % {
"record_type": get_record_type_enum_name(clazz.name),
"record_id_source": record_id_source,
}
initializer_args = [
"id",
"recordType",
"uniqueId",
]
inherited_property_map = {}
for property in properties_and_inherited_properties(clazz):
inherited_property_map[property.column_name()] = property
def write_record_property(property, force_optional=False):
optional_value = ""
if property.column_name() in inherited_property_map:
inherited_property = inherited_property_map[property.column_name()]
did_force_optional = property.force_optional
model_accessor = accessor_name_for_property(inherited_property)
value_expr = inherited_property.serialize_record_invocation(
"model.%s" % (model_accessor,), did_force_optional
)
optional_value = " = %s" % (value_expr,)
else:
optional_value = " = nil"
record_field_type = property.record_field_type()
is_optional = property.is_optional or force_optional
optional_split = "?" if is_optional else ""
initializer_args.append(property.column_name())
return """ let %s: %s%s%s
""" % (
str(property.column_name()),
record_field_type,
optional_split,
optional_value,
)
root_record_properties = root_class.sorted_record_properties()
if len(root_record_properties) > 0:
swift_body += "\n // Properties \n"
for property in root_record_properties:
swift_body += write_record_property(
property, force_optional=property.force_optional
)
initializer_args = [
"%s: %s"
% (
arg,
arg,
)
for arg in initializer_args
]
swift_body += """
return %s(delegate: model, %s)
}
""" % (
root_record_name,
", ".join(initializer_args),
)
swift_body += """}
"""
print(f"Writing {swift_filename}")
swift_body = sds_common.clean_up_generated_swift(swift_body)
sds_common.write_text_file_if_changed(swift_filepath, swift_body)
def process_class_map(class_map):
    """Generate the Swift SDS extensions for every parsed class in *class_map*."""
    for class_name in class_map:
        generate_swift_extensions_for_model(class_map[class_name])
# ---- Record Type Map
# Maps model class name -> persisted integer record type.
# update_record_type_map loads it from the JSON file and extends it with newly
# seen classes; get_record_type reads from it.
# Keys starting with "#" hold comments rather than class names.
record_type_map = {}
# It's critical that our "record type" values are consistent, even if we add/remove/rename model classes.
# Therefore we persist the mapping of known classes in a JSON file that is under source control.
def update_record_type_map(record_type_swift_path, record_type_json_path):
    """Load, extend, and persist the class -> record type mapping.

    1. Merges the persisted mapping from *record_type_json_path* (if the file
       exists) into the module-level ``record_type_map``.
    2. Gives each class in ``global_class_map`` that has no record type yet, and
       whose ``should_generate_extensions()`` is true, a new id one past the
       current maximum.
    3. Writes the mapping back to *record_type_json_path* and regenerates the
       ``SDSRecordType`` Swift enum at *record_type_swift_path*.

    Keys beginning with "#" are comments. They are skipped both when computing
    the maximum id and when emitting enum cases.
    """
    record_type_map_filepath = record_type_json_path
    if os.path.exists(record_type_map_filepath):
        with open(record_type_map_filepath, "rt") as f:
            record_type_map.update(json.load(f))

    max_record_type = int(
        max(
            [0]
            + [
                record_type
                for class_name, record_type in record_type_map.items()
                if not class_name.startswith("#")
            ]
        )
    )

    # Assign ids to new classes in name order. This keeps the result from
    # depending on global_class_map's iteration order, which presumably
    # reflects file discovery order (not guaranteed stable across machines).
    # Otherwise, two classes added in the same run could swap ids.
    for clazz in sorted(global_class_map.values(), key=lambda c: c.name):
        if clazz.name in record_type_map:
            continue
        if not clazz.should_generate_extensions():
            continue
        max_record_type += 1
        record_type_map[clazz.name] = max_record_type

    record_type_map["#comment"] = (
        "NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`."
        % (sds_common.pretty_module_path(__file__),)
    )
    json_string = json.dumps(record_type_map, sort_keys=True, indent=4)
    sds_common.write_text_file_if_changed(record_type_map_filepath, json_string)

    # TODO: We'll need to import SignalServiceKit for non-SSK classes.
    swift_body = """//
// Copyright 2022 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import Foundation
import GRDB
// NOTE: This file is generated by %s.
// Do not manually edit it, instead run `sds_codegen.sh`.
@objc
public enum SDSRecordType: UInt, CaseIterable {
""" % (
        sds_common.pretty_module_path(__file__),
    )
    # Emit enum cases ordered by their numeric record type; "#" keys are comments.
    record_type_pairs = sorted(
        (
            (str(get_record_type_enum_name(key)), record_type)
            for key, record_type in record_type_map.items()
            if not key.startswith("#")
        ),
        key=lambda pair: pair[1],
    )
    for enum_name, record_type_id in record_type_pairs:
        swift_body += """ case %s = %s
""" % (
            enum_name,
            str(record_type_id),
        )
    swift_body += """}
"""
    swift_body = sds_common.clean_up_generated_swift(swift_body)
    sds_common.write_text_file_if_changed(record_type_swift_path, swift_body)
def get_record_type(clazz):
    """Return the persisted record type id for *clazz*, looked up by its class name."""
    class_name = clazz.name
    return record_type_map[class_name]
def remove_prefix_from_class_name(class_name):
    """Strip the first matching Obj-C class prefix ("TS", "OWS", "SSK").

    Prefixes are tried in that order, and at most one is removed. Names without
    a known prefix are returned unchanged.
    """
    for prefix in ("TS", "OWS", "SSK"):
        if class_name.startswith(prefix):
            return class_name[len(prefix) :]
    return class_name
def get_record_type_enum_name(class_name):
    """Return the SDSRecordType enum case name for *class_name*.

    The class prefix is stripped. If the remainder starts with a digit, it is
    prefixed with "_" because Swift identifiers cannot begin with a digit.
    The result is then passed through to_swift_identifier_name.
    """
    stripped = remove_prefix_from_class_name(class_name)
    leading = "_" if stripped[0].isnumeric() else ""
    return to_swift_identifier_name(leading + stripped)
def record_identifier(class_name):
    """Return the prefix-stripped Swift identifier used to name generated
    column interpolations (e.g. ``<identifier>Column``) for *class_name*."""
    return to_swift_identifier_name(remove_prefix_from_class_name(class_name))
# ---- Column Ordering
# Module-level column-ordering state. Its readers/writers are not in this
# section of the file.
# NOTE(review): confirm these are still used before relying on them.
column_ordering_map = {}
has_loaded_column_ordering_map = False
# ---- Parsing
# Maps enum name -> Objective-C underlying type name (e.g. "NSInteger").
# parse_sds_json fills it from the "enums" section of each SDS JSON file;
# objc_type_for_enum reads from it.
enum_type_map = {}
def objc_type_for_enum(enum_name):
    """Return the ObjC underlying type name recorded for ``enum_name``.

    Unknown enums dump the whole ``enum_type_map`` for debugging and then
    call ``fail``.
    """
    is_known = enum_name in enum_type_map
    if not is_known:
        print("enum_type_map", enum_type_map)
        fail("Enum has unknown type:", enum_name)
    return enum_type_map[enum_name]
# Maps the ObjC underlying type of an enum (as recorded in enum_type_map) to
# the Swift type name emitted in generated code.
OBJC_ENUM_TYPE_TO_SWIFT_TYPE = {
    "NSInteger": "Int",
    "NSUInteger": "UInt",
    "int32_t": "Int32",
    # Previously this type hit a duplicated branch that returned the ObjC
    # spelling "uint64_t"; the intended Swift type is UInt64.
    "unsigned long long": "UInt64",
    "unsigned long": "UInt64",
    # NOTE(review): C `unsigned int` is 32 bits, so UInt32 would be the exact
    # match; kept as UInt to avoid changing existing generated code.
    "unsigned int": "UInt",
}


def swift_type_for_enum(enum_name):
    """Return the Swift type name to use for values of enum ``enum_name``.

    Args:
        enum_name: Name of an enum present in ``enum_type_map``.

    Returns:
        A Swift type name such as "Int" or "UInt64".

    Calls ``fail`` if the enum's ObjC underlying type has no Swift mapping.
    """
    objc_type = objc_type_for_enum(enum_name)
    swift_type = OBJC_ENUM_TYPE_TO_SWIFT_TYPE.get(objc_type)
    if swift_type is None:
        fail("Unknown objc type:", objc_type)
    return swift_type
def parse_sds_json(file_path):
    """Load one .sds intermediary JSON file.

    Builds a ``ParsedClass`` for every entry in its "classes" list and merges
    its "enums" dict into the module-level ``enum_type_map``.

    Returns:
        Dict mapping class name -> ParsedClass. Later duplicates win.
    """
    with open(file_path, "rt") as sds_file:
        payload = json.load(sds_file)
    parsed_classes = (ParsedClass(entry) for entry in payload["classes"])
    class_map = {parsed.name: parsed for parsed in parsed_classes}
    enum_type_map.update(payload["enums"])
    return class_map
def try_to_parse_file(file_path):
    """Parse ``file_path`` if it is an SDS intermediary JSON file.

    Args:
        file_path: Path of a candidate file.

    Returns:
        The class map from parse_sds_json() when the file name ends with
        ``sds_common.SDS_JSON_FILE_EXTENSION``, otherwise an empty dict.
    """
    # The previous version also computed the file extension via
    # os.path.splitext but never used it; only the suffix check matters.
    if os.path.basename(file_path).endswith(sds_common.SDS_JSON_FILE_EXTENSION):
        return parse_sds_json(file_path)
    return {}
def find_sds_intermediary_files_in_path(path):
    """Collect parsed classes from ``path``.

    ``path`` is either a single file or a directory walked recursively.
    Every file is passed to try_to_parse_file(), and the resulting class maps
    are merged, with later files overriding earlier ones on name clashes.
    """
    if os.path.isfile(path):
        candidates = [path]
    else:
        candidates = (
            os.path.abspath(os.path.join(directory, name))
            for directory, _subdirs, names in os.walk(path)
            for name in names
        )
    merged = {}
    for candidate in candidates:
        merged.update(try_to_parse_file(candidate))
    return merged
def update_subclass_map():
    """Index ``global_class_map`` by superclass name into ``global_subclass_map``.

    Each class with a superclass is appended to the list stored under its
    superclass's name. Classes without a superclass are skipped.
    """
    for candidate in global_class_map.values():
        parent_name = candidate.super_class_name
        if parent_name is None:
            continue
        global_subclass_map.setdefault(parent_name, []).append(candidate)
def all_descendents_of_class(clazz):
    """Return every transitive subclass of ``clazz`` in depth-first pre-order.

    Siblings are visited in name order. Note that each visited class's list
    in ``global_subclass_map`` is sorted in place as a side effect.
    """

    def sorted_children(node):
        children = global_subclass_map.get(node.name, [])
        children.sort(key=lambda child: child.name)
        return children

    descendants = []
    # Children are pushed in reverse so the alphabetically-first one is popped
    # next, which reproduces recursive pre-order traversal.
    stack = list(reversed(sorted_children(clazz)))
    while stack:
        node = stack.pop()
        descendants.append(node)
        stack.extend(reversed(sorted_children(node)))
    return descendants
def is_swift_class_name(swift_type):
    """Return True if ``swift_type`` names a class present in ``global_class_map``."""
    parsed_class = global_class_map.get(swift_type)
    return parsed_class is not None
# ---- Config JSON
# Code-generation configuration. Replaced wholesale by parse_config_json()
# with the contents of the file given by --config-json-path, and read by the
# *_for_property / *_for_class accessors below.
configuration_json = {}
def parse_config_json(config_json_path):
    """Load the code-generation config file into ``configuration_json``."""
    global configuration_json
    with open(config_json_path, "rt") as config_file:
        configuration_json = json.load(config_file)
# We often use nullable NSNumber * for optional numerics (bool, int, int64, double, etc.).
# There's no way to infer which type is boxed in the NSNumber,
# so the configuration JSON must specify it.
def swift_type_for_nsnumber(property):
    """Return the configured Swift type for an NSNumber-typed property.

    The lookup key is "<class_name>.<property name>" in the config's
    "nsnumber_types" dict. Calls ``fail`` (after printing a hint with the
    config path) when the dict or the key is missing.
    """

    def suggest_config_update():
        print("Suggestion: update: %s" % (str(global_args.config_json_path),))

    boxed_types = configuration_json.get("nsnumber_types")
    if boxed_types is None:
        suggest_config_update()
        fail("Configuration JSON is missing mapping for properties of type NSNumber.")
    lookup_key = ".".join((property.class_name, property.name))
    swift_type = boxed_types.get(lookup_key)
    if swift_type is None:
        suggest_config_update()
        fail(
            "Configuration JSON is missing mapping for properties of type NSNumber:",
            lookup_key,
        )
    return swift_type
# Some properties shouldn't get serialized, e.g. TSGroupModel.groupImage, which is a UIImage.
# The list of such properties comes from the configuration JSON ("properties_to_ignore").
# TODO: We might end up extending the serialization to handle images,
# or we might store these as Data/NSData/blob.
def should_ignore_property(property):
    """Return True if "<class_name>.<name>" is listed in "properties_to_ignore"."""
    ignored_keys = configuration_json.get("properties_to_ignore")
    if ignored_keys is None:
        fail(
            "Configuration JSON is missing list of properties to ignore during serialization."
        )
    return ".".join((property.class_name, property.name)) in ignored_keys
def cache_get_code_for_class(clazz):
    """Return the configured "class_cache_get_code" entry for ``clazz``.

    Returns None when the class has no entry; calls ``fail`` if the config
    lacks the dict entirely.
    """
    snippets_by_class = configuration_json.get("class_cache_get_code")
    if snippets_by_class is None:
        fail("Configuration JSON is missing dict of class_cache_get_code.")
    return snippets_by_class.get(clazz.name)
def cache_set_code_for_class(clazz):
    """Return the configured "class_cache_set_code" entry for ``clazz``.

    Returns None when the class has no entry; calls ``fail`` if the config
    lacks the dict entirely.
    """
    snippets_by_class = configuration_json.get("class_cache_set_code")
    if snippets_by_class is None:
        fail("Configuration JSON is missing dict of class_cache_set_code.")
    return snippets_by_class.get(clazz.name)
def should_ignore_class(clazz):
    """Return True if ``clazz`` or one of its ancestors is skipped.

    A class counts as skipped when its name is listed in the config's
    "class_to_skip_serialization". The superclass chain is followed only
    while each superclass is present in ``global_class_map``.
    """
    skipped_names = configuration_json.get("class_to_skip_serialization")
    if skipped_names is None:
        fail(
            "Configuration JSON is missing list of classes to ignore during serialization."
        )
    current = clazz
    while True:
        if current.name in skipped_names:
            return True
        parent_name = current.super_class_name
        if parent_name is None or parent_name not in global_class_map:
            return False
        current = global_class_map[parent_name]
def accessor_name_for_property(property):
    """Return the configured custom accessor name, defaulting to ``property.name``."""
    accessors = configuration_json.get("custom_accessors")
    if accessors is None:
        fail("Configuration JSON is missing list of custom property accessors.")
    lookup_key = ".".join((property.class_name, property.name))
    return accessors.get(lookup_key, property.name)
# include_renamed_columns
def custom_column_name_for_property(property):
    """Return the configured custom column name for ``property``, or None."""
    column_overrides = configuration_json.get("custom_column_names")
    if column_overrides is None:
        fail("Configuration JSON is missing list of custom column names.")
    return column_overrides.get(".".join((property.class_name, property.name)))
def aliased_column_name_for_property(property):
    """Return the configured "aliased_column_names" entry for ``property``, or None."""
    aliased_column_names = configuration_json.get("aliased_column_names")
    if aliased_column_names is None:
        fail("Configuration JSON is missing dict of aliased_column_names.")
    return aliased_column_names.get(".".join((property.class_name, property.name)))
def was_property_renamed_for_property(property):
    """Return True if the config has a non-null "renamed_column_names" entry for ``property``."""
    renamed = configuration_json.get("renamed_column_names")
    if renamed is None:
        fail("Configuration JSON is missing list of renamed column names.")
    lookup_key = ".".join((property.class_name, property.name))
    new_name = renamed.get(lookup_key)
    return new_name is not None
# ---- Property Order JSON
# Persisted property ordering: maps "<record_name>.<property_name>" to its
# order value. Loaded by parse_property_order_json() and written back by
# update_property_order_json().
property_order_json = {}
def parse_property_order_json(property_order_json_path):
    """Load the persisted property-order file into ``property_order_json``."""
    global property_order_json
    with open(property_order_json_path, "rt") as order_file:
        property_order_json = json.load(order_file)
# It's critical that our "property order" is consistent, even if we add columns.
# Therefore we persist the "property order" for all known properties in a JSON file that is under source control.
def update_property_order_json(property_order_json_path):
    """Stamp ``property_order_json`` with a generated-file note and persist it.

    The "#comment" key is written into the in-memory dict as well. Keys are
    sorted and indented so the file diffs cleanly; the file is only rewritten
    (via sds_common) if its text changed.
    """
    generator_path = sds_common.pretty_module_path(__file__)
    property_order_json["#comment"] = (
        "NOTE: This file is generated by %s. Do not manually edit it, instead run `sds_codegen.sh`."
        % (generator_path,)
    )
    serialized = json.dumps(property_order_json, sort_keys=True, indent=4)
    sds_common.write_text_file_if_changed(property_order_json_path, serialized)
def property_order_key(property, record_name):
    """Return the property-order lookup key "<record_name>.<property.name>"."""
    return ".".join((record_name, property.name))
def property_order_for_property(property, record_name):
    """Return the persisted order for ``property`` within ``record_name``, or None."""
    return property_order_json.get(property_order_key(property, record_name))
def set_property_order_for_property(property, record_name, value):
    """Record ``value`` as the order of ``property`` within ``record_name``."""
    lookup_key = property_order_key(property, record_name)
    property_order_json[lookup_key] = value
if __name__ == "__main__":
    # Command-line entry point: parse arguments, load the config and the
    # persisted property order, build the class hierarchy, generate code,
    # then write the (possibly extended) property order back to disk.
    parser = argparse.ArgumentParser(description="Generate Swift extensions.")
    # Classes found under --src-path are the ones passed to process_class_map().
    parser.add_argument(
        "--src-path", required=True, help="used to specify a path to process."
    )
    # Classes found under --search-path populate global_class_map, which is
    # used to resolve the class hierarchy.
    parser.add_argument(
        "--search-path", required=True, help="used to specify a path to process."
    )
    parser.add_argument(
        "--record-type-swift-path",
        required=True,
        help="path of the record type enum swift file.",
    )
    parser.add_argument(
        "--record-type-json-path",
        required=True,
        help="path of the record type map json file.",
    )
    parser.add_argument(
        "--config-json-path",
        required=True,
        help="path of the json file with code generation config info.",
    )
    parser.add_argument(
        "--property-order-json-path",
        required=True,
        help="path of the json file with property ordering cache.",
    )
    args = parser.parse_args()
    # Module-level handle on the parsed arguments; e.g. swift_type_for_nsnumber()
    # reads global_args.config_json_path when printing suggestions.
    global_args = args
    src_path = os.path.abspath(args.src_path)
    search_path = os.path.abspath(args.search_path)
    record_type_swift_path = os.path.abspath(args.record_type_swift_path)
    record_type_json_path = os.path.abspath(args.record_type_json_path)
    config_json_path = os.path.abspath(args.config_json_path)
    property_order_json_path = os.path.abspath(args.property_order_json_path)
    # We control the code generation process using a JSON config file.
    parse_config_json(config_json_path)
    parse_property_order_json(property_order_json_path)
    # The code generation needs to understand the class hierarchy so that
    # it can:
    #
    # * Define table schemas that include the superset of properties in
    #   the model class hierarchies.
    # * Generate deserialization methods that handle all subclasses.
    # * etc.
    global_class_map.update(find_sds_intermediary_files_in_path(search_path))
    update_subclass_map()
    update_record_type_map(record_type_swift_path, record_type_json_path)
    process_class_map(find_sds_intermediary_files_in_path(src_path))
    # Persist updated property order
    update_property_order_json(property_order_json_path)