TM-SGNL-iOS/SignalServiceKit/Storage/Database/KeyValueStore.swift
TeleMessage developers dde0620daf initial commit
2025-05-03 12:28:28 -07:00

556 lines
20 KiB
Swift

//
// Copyright 2024 Signal Messenger, LLC
// SPDX-License-Identifier: AGPL-3.0-only
//
import GRDB
public struct KeyValueStore {
private enum TableMetadata {
enum Columns {
static let collection = "collection"
static let key = "key"
static let value = "value"
}
static let tableName = "keyvalue"
}
static var tableName: String { TableMetadata.tableName }
static var collectionColumnName: String { TableMetadata.Columns.collection }
private let collection: String
public init(collection: String) {
self.collection = collection
}
public static func logCollectionStatistics() {
Logger.info("KeyValueStore statistics:")
SSKEnvironment.shared.databaseStorageRef.read { transaction in
do {
let sql = """
SELECT \(TableMetadata.Columns.collection), COUNT(*)
FROM \(TableMetadata.tableName)
GROUP BY \(TableMetadata.Columns.collection)
ORDER BY COUNT(*) DESC
LIMIT 10
"""
let cursor = try Row.fetchCursor(transaction.unwrapGrdbRead.database, sql: sql)
while let row = try cursor.next() {
let collection: String = row[0]
let count: UInt = row[1]
Logger.info("- \(collection): \(count) items")
}
} catch {
Logger.error("\(error)")
}
}
}
// MARK: -
public func hasValue(_ key: String, transaction: DBReadTransaction) -> Bool {
do {
let count = try UInt.fetchOne(
transaction.databaseConnection,
sql: """
SELECT
COUNT(*)
FROM \(TableMetadata.tableName)
WHERE \(TableMetadata.Columns.key) = ?
AND \(TableMetadata.Columns.collection) == ?
""",
arguments: [key, collection]
) ?? 0
return count > 0
} catch {
owsFailDebug("error: \(error)")
return false
}
}
// MARK: - String
public func getString(_ key: String, transaction: DBReadTransaction) -> String? {
return getObject(key, ofClass: NSString.self, transaction: transaction) as String?
}
public func setString(_ value: String?, key: String, transaction: DBWriteTransaction) {
guard let value = value else {
write(nil, forKey: key, transaction: transaction)
return
}
write(value as NSString, forKey: key, transaction: transaction)
}
// MARK: - Date
public func getDate(_ key: String, transaction: DBReadTransaction) -> Date? {
// Our legacy methods sometimes stored dates as NSNumber and
// sometimes as NSDate, so we are permissive when decoding.
let value = getObject(key, ofClasses: [NSDate.self, NSNumber.self], transaction: transaction)
if let dateValue = value as? NSDate {
return dateValue as Date
}
if let numberValue = value as? NSNumber {
return Date(timeIntervalSince1970: numberValue.doubleValue)
}
return nil
}
public func setDate(_ value: Date, key: String, transaction: DBWriteTransaction) {
let epochInterval = NSNumber(value: value.timeIntervalSince1970)
setObject(epochInterval, key: key, transaction: transaction)
}
// MARK: - Bool
public func getBool(_ key: String, transaction: DBReadTransaction) -> Bool? {
return getNSNumber(key, transaction: transaction)?.boolValue
}
public func getBool(_ key: String, defaultValue: Bool, transaction: DBReadTransaction) -> Bool {
return getBool(key, transaction: transaction) ?? defaultValue
}
public func setBool(_ value: Bool, key: String, transaction: DBWriteTransaction) {
write(NSNumber(value: value), forKey: key, transaction: transaction)
}
// MARK: - UInt
public func getUInt(_ key: String, transaction: DBReadTransaction) -> UInt? {
return getNSNumber(key, transaction: transaction)?.uintValue
}
public func getUInt(_ key: String, defaultValue: UInt, transaction: DBReadTransaction) -> UInt {
return getUInt(key, transaction: transaction) ?? defaultValue
}
public func setUInt(_ value: UInt, key: String, transaction: DBWriteTransaction) {
write(NSNumber(value: value), forKey: key, transaction: transaction)
}
// MARK: - Data
public func getData(_ key: String, transaction: DBReadTransaction) -> Data? {
do {
return try Data.fetchOne(
transaction.databaseConnection,
sql: """
SELECT \(TableMetadata.Columns.value)
FROM \(TableMetadata.tableName)
WHERE
\(TableMetadata.Columns.key) = ?
AND \(TableMetadata.Columns.collection) == ?
""",
arguments: [key, collection]
)
} catch {
DatabaseCorruptionState.flagDatabaseReadCorruptionIfNecessary(
userDefaults: CurrentAppContext().appUserDefaults(),
error: error
)
owsFailDebug("error: \(error)")
return nil
}
}
public func setData(_ data: Data?, key: String, transaction: DBWriteTransaction) {
do {
let sql: String
let arguments: StatementArguments
if let data {
// See: https://www.sqlite.org/lang_UPSERT.html
sql = """
INSERT INTO \(TableMetadata.tableName) (
\(TableMetadata.Columns.key),
\(TableMetadata.Columns.collection),
\(TableMetadata.Columns.value)
) VALUES (?, ?, ?)
ON CONFLICT (
\(TableMetadata.Columns.key),
\(TableMetadata.Columns.collection)
) DO UPDATE
SET \(TableMetadata.Columns.value) = ?
"""
arguments = [key, collection, data, data]
} else {
// Setting to nil is a delete.
sql = """
DELETE FROM \(TableMetadata.tableName)
WHERE
\(TableMetadata.Columns.key) == ?
AND \(TableMetadata.Columns.collection) == ?
"""
arguments = [key, collection]
}
let statement = try transaction.databaseConnection.cachedStatement(sql: sql)
try statement.setArguments(arguments)
do {
try statement.execute()
} catch {
DatabaseCorruptionState.flagDatabaseCorruptionIfNecessary(
userDefaults: CurrentAppContext().appUserDefaults(),
error: error
)
throw error
}
} catch {
owsFailDebug("Error: \(error)")
}
}
// MARK: - Int
public func getInt(_ key: String, transaction: DBReadTransaction) -> Int? {
return getNSNumber(key, transaction: transaction)?.intValue
}
public func getInt(_ key: String, defaultValue: Int, transaction: DBReadTransaction) -> Int {
return getInt(key, transaction: transaction) ?? defaultValue
}
public func setInt(_ value: Int, key: String, transaction: DBWriteTransaction) {
setObject(NSNumber(value: value), key: key, transaction: transaction)
}
// MARK: - Int32
public func getInt32(_ key: String, transaction: DBReadTransaction) -> Int32? {
return getNSNumber(key, transaction: transaction)?.int32Value
}
public func getInt32(_ key: String, defaultValue: Int32, transaction: DBReadTransaction) -> Int32 {
return getInt32(key, transaction: transaction) ?? defaultValue
}
public func setInt32(_ value: Int32, key: String, transaction: DBWriteTransaction) {
setObject(NSNumber(value: value), key: key, transaction: transaction)
}
// MARK: - UInt32
public func getUInt32(_ key: String, transaction: DBReadTransaction) -> UInt32? {
return getNSNumber(key, transaction: transaction)?.uint32Value
}
public func getUInt32(_ key: String, defaultValue: UInt32, transaction: DBReadTransaction) -> UInt32 {
return getUInt32(key, transaction: transaction) ?? defaultValue
}
public func setUInt32(_ value: UInt32, key: String, transaction: DBWriteTransaction) {
setObject(NSNumber(value: value), key: key, transaction: transaction)
}
// MARK: - UInt64
public func getUInt64(_ key: String, transaction: DBReadTransaction) -> UInt64? {
return getNSNumber(key, transaction: transaction)?.uint64Value
}
public func getUInt64(_ key: String, defaultValue: UInt64, transaction: DBReadTransaction) -> UInt64 {
return getUInt64(key, transaction: transaction) ?? defaultValue
}
public func setUInt64(_ value: UInt64, key: String, transaction: DBWriteTransaction) {
setObject(NSNumber(value: value), key: key, transaction: transaction)
}
// MARK: - Int64
public func getInt64(_ key: String, transaction: DBReadTransaction) -> Int64? {
return getNSNumber(key, transaction: transaction)?.int64Value
}
public func getInt64(_ key: String, defaultValue: Int64, transaction: DBReadTransaction) -> Int64 {
return getInt64(key, transaction: transaction) ?? defaultValue
}
public func setInt64(_ value: Int64, key: String, transaction: DBWriteTransaction) {
setObject(NSNumber(value: value), key: key, transaction: transaction)
}
// MARK: - Double
public func getDouble(_ key: String, transaction: DBReadTransaction) -> Double? {
return getNSNumber(key, transaction: transaction)?.doubleValue
}
public func getDouble(_ key: String, defaultValue: Double, transaction: DBReadTransaction) -> Double {
return getDouble(key, transaction: transaction) ?? defaultValue
}
public func setDouble(_ value: Double, key: String, transaction: DBWriteTransaction) {
setObject(NSNumber(value: value), key: key, transaction: transaction)
}
// MARK: - Object
public func setObject(_ anyValue: Any?, key: String, transaction: DBWriteTransaction) {
guard let anyValue = anyValue else {
write(nil, forKey: key, transaction: transaction)
return
}
guard let codingValue = anyValue as? NSCoding else {
owsFailDebug("Invalid value.")
write(nil, forKey: key, transaction: transaction)
return
}
write(codingValue, forKey: key, transaction: transaction)
}
// MARK: -
public func removeValue(forKey key: String, transaction: DBWriteTransaction) {
write(nil, forKey: key, transaction: transaction)
}
public func removeValues(forKeys keys: [String], transaction: DBWriteTransaction) {
for key in keys {
write(nil, forKey: key, transaction: transaction)
}
}
public func removeAll(transaction: DBWriteTransaction) {
transaction.databaseConnection.executeAndCacheStatementHandlingErrors(
sql: """
DELETE
FROM \(TableMetadata.tableName)
WHERE \(TableMetadata.Columns.collection) == ?
""",
arguments: [collection]
)
}
public func allDataValues(transaction: DBReadTransaction) -> [Data] {
return allKeys(transaction: transaction).compactMap { key in
return self.getData(key, transaction: transaction)
}
}
// MARK: -
private struct PairRecord: Codable, FetchableRecord, PersistableRecord {
let key: String
let value: Data
}
public func allUIntValuesMap(transaction: DBReadTransaction) -> [String: UInt] {
let allPairs: [PairRecord] = {
do {
let sql = """
SELECT
\(TableMetadata.Columns.key),
\(TableMetadata.Columns.value)
FROM \(TableMetadata.tableName)
WHERE \(TableMetadata.Columns.collection) == ?
"""
return try PairRecord.fetchAll(
transaction.databaseConnection,
sql: sql,
arguments: [collection]
)
} catch {
owsFailDebug("Error: \(error)")
return []
}
}()
var result = [String: UInt]()
for pair in allPairs {
guard let numberValue = parseArchivedValue(pair.value, ofClass: NSNumber.self) else {
owsFailDebug("Could not parse value.")
continue
}
result[pair.key] = numberValue.uintValue
}
return result
}
// MARK: -
public func anyDataValue(transaction: DBReadTransaction) -> Data? {
let randomKey = allKeys(transaction: transaction).randomElement()
guard let randomKey else {
return nil
}
guard let dataValue = self.getData(randomKey, transaction: transaction) else {
owsFailDebug("Couldn't fetch random element")
return nil
}
return dataValue
}
public func allKeys(transaction: DBReadTransaction) -> [String] {
let sql = """
SELECT \(TableMetadata.Columns.key)
FROM \(TableMetadata.tableName)
WHERE \(TableMetadata.Columns.collection) == ?
"""
return transaction.databaseConnection.strictRead { database in
try String.fetchAll(database, sql: sql, arguments: [collection])
}
}
public func numberOfKeys(transaction: DBReadTransaction) -> UInt {
do {
let sql = """
SELECT COUNT(*)
FROM \(TableMetadata.tableName)
WHERE \(TableMetadata.Columns.collection) == ?
"""
guard let numberOfKeys = try UInt.fetchOne(
transaction.databaseConnection,
sql: sql,
arguments: [collection]
) else {
throw OWSAssertionError("numberOfKeys was unexpectedly nil")
}
return numberOfKeys
} catch {
DatabaseCorruptionState.flagDatabaseReadCorruptionIfNecessary(
userDefaults: CurrentAppContext().appUserDefaults(),
error: error
)
owsFail("error: \(error)")
}
}
// MARK: -
@available(*, deprecated, message: "Did you mean removeValue(forKey:transaction:) or setCodable(optional:key:transaction)?")
public func setCodable<T: Encodable>(_ value: T?, key: String, transaction: DBWriteTransaction) throws {
try setCodable(optional: value, key: key, transaction: transaction)
}
public func setCodable<T: Encodable>(_ value: T, key: String, transaction: DBWriteTransaction) throws {
// The only difference between setCodable(optional:...) and setCodable(_...) is
// the non-optional variant has a deprecated overload to warn callers of the ambiguity.
try setCodable(optional: value, key: key, transaction: transaction)
}
public func setCodable<T: Encodable>(optional value: T, key: String, transaction: DBWriteTransaction) throws {
do {
let data = try JSONEncoder().encode(value)
setData(data, key: key, transaction: transaction)
} catch {
owsFailDebug("Failed to encode: \(error).")
throw error
}
}
public func getCodableValue<T: Decodable>(forKey key: String, transaction: DBReadTransaction) throws -> T? {
guard let data = getData(key, transaction: transaction) else {
return nil
}
do {
return try JSONDecoder().decode(T.self, from: data)
} catch {
owsFailDebug("Failed to decode: \(error).")
throw error
}
}
public func allCodableValues<T: Decodable>(transaction: DBReadTransaction) throws -> [T] {
var result = [T]()
for data in allDataValues(transaction: transaction) {
do {
result.append(try JSONDecoder().decode(T.self, from: data))
} catch {
owsFailDebug("Failed to decode: \(error).")
throw error
}
}
return result
}
// MARK: - Archived Values
public func getObject<ObjectType: NSObject & NSSecureCoding>(_ key: String, ofClass cls: ObjectType.Type, transaction: any DBReadTransaction) -> ObjectType? {
return self.getData(key, transaction: transaction).flatMap { self.parseArchivedValue($0, ofClass: cls) }
}
public func getObject(_ key: String, ofClasses classes: [AnyClass], transaction: any DBReadTransaction) -> Any? {
return self.getData(key, transaction: transaction).flatMap {
do {
return try NSKeyedUnarchiver.unarchivedObject(ofClasses: classes, from: $0)
} catch {
owsFailDebug("Couldn't decode value.")
return nil
}
}
}
private func parseArchivedValue<DecodedObject: NSObject & NSSecureCoding>(_ encodedData: Data, ofClass cls: DecodedObject.Type) -> DecodedObject? {
do {
return try NSKeyedUnarchiver.unarchivedObject(ofClass: cls, from: encodedData)
} catch {
owsFailDebug("Couldn't decode value.")
return nil
}
}
private func getNSNumber(_ key: String, transaction: any DBReadTransaction) -> NSNumber? {
return getObject(key, ofClass: NSNumber.self, transaction: transaction)
}
public func getDictionary<DecodedKey: NSObject & NSCopying & NSSecureCoding, DecodedObject: NSObject & NSSecureCoding>(
_ key: String,
keyClass: DecodedKey.Type,
objectClass: DecodedObject.Type,
transaction: any DBReadTransaction
) -> [DecodedKey: DecodedObject]? {
return self.getData(key, transaction: transaction).flatMap {
do {
return try NSKeyedUnarchiver.unarchivedDictionary(ofKeyClass: keyClass, objectClass: objectClass, from: $0)
} catch {
owsFailDebug("Decode failed.")
return nil
}
}
}
public func getArray<DecodedObject: NSObject & NSSecureCoding>(_ key: String, ofClass cls: DecodedObject.Type, transaction: any DBReadTransaction) -> [DecodedObject]? {
return self.getData(key, transaction: transaction).flatMap {
do {
return try NSKeyedUnarchiver.unarchivedArrayOfObjects(ofClass: cls, from: $0)
} catch {
owsFailDebug("Decode failed.")
return nil
}
}
}
public func getStringArray(_ key: String, transaction: any DBReadTransaction) -> [String]? {
return self.getArray(key, ofClass: NSString.self, transaction: transaction) as [String]?
}
public func getSet<DecodedObject: NSObject & NSSecureCoding>(_ key: String, ofClass cls: DecodedObject.Type, transaction: any DBReadTransaction) -> Set<DecodedObject>? {
return self.getObject(key, ofClasses: [NSSet.self, cls], transaction: transaction) as? Set<DecodedObject>
}
// MARK: - Internal Methods
private func write(
_ value: NSCoding?,
forKey key: String,
transaction: DBWriteTransaction
) {
let encoded: Data? = value.flatMap {
try? NSKeyedArchiver.archivedData(
withRootObject: $0,
requiringSecureCoding: false
)
}
setData(encoded, key: key, transaction: transaction)
}
}