From 37c101d175ddca390dc48014a5087c8291772b12 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 26 Apr 2019 14:58:53 -0700 Subject: [PATCH] Updates the TensorFlow Lite Swift library to Swift 5. PiperOrigin-RevId: 245493147 --- .../swift/Sources/Interpreter.swift | 28 ++++++++++--------- .../swift/Sources/InterpreterError.swift | 28 ------------------- .../TensorFlowLiteApp/AppDelegate.swift | 8 ------ .../Array+TensorFlowLite.swift | 11 +++++--- .../swift/Tests/InterpreterTests.swift | 14 +++++++--- .../experimental/swift/Tests/ModelTests.swift | 6 ++-- 6 files changed, 35 insertions(+), 60 deletions(-) diff --git a/tensorflow/lite/experimental/swift/Sources/Interpreter.swift b/tensorflow/lite/experimental/swift/Sources/Interpreter.swift index 47ea935322e..c1aea0a067e 100644 --- a/tensorflow/lite/experimental/swift/Sources/Interpreter.swift +++ b/tensorflow/lite/experimental/swift/Sources/Interpreter.swift @@ -87,7 +87,6 @@ public final class Interpreter { /// - Throws: An error if the model was not ready because tensors were not allocated. public func invoke() throws { guard TFL_InterpreterInvoke(cInterpreter) == kTfLiteOk else { - // TODO(b/117510052): Determine which error to throw. throw InterpreterError.allocateTensorsRequired } } @@ -104,8 +103,8 @@ public final class Interpreter { throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) } guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)), - let bytes = TFL_TensorData(cTensor), - let nameCString = TFL_TensorName(cTensor) + let bytes = TFL_TensorData(cTensor), + let nameCString = TFL_TensorName(cTensor) else { throw InterpreterError.allocateTensorsRequired } @@ -124,7 +123,6 @@ public final class Interpreter { let zeroPoint = Int(cQuantizationParams.zero_point) var quantizationParameters: QuantizationParameters? = nil if scale != 0.0 { - // TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode. quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint) } let tensor = Tensor( @@ -151,10 +149,9 @@ public final class Interpreter { throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) } guard let cTensor = TFL_InterpreterGetOutputTensor(cInterpreter, Int32(index)), - let bytes = TFL_TensorData(cTensor), - let nameCString = TFL_TensorName(cTensor) + let bytes = TFL_TensorData(cTensor), + let nameCString = TFL_TensorName(cTensor) else { - // TODO(b/117510052): Determine which error to throw. throw InterpreterError.invokeInterpreterRequired } guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else { @@ -172,7 +169,6 @@ public final class Interpreter { let zeroPoint = Int(cQuantizationParams.zero_point) var quantizationParameters: QuantizationParameters? = nil if scale != 0.0 { - // TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode. quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint) } let tensor = Tensor( @@ -200,11 +196,11 @@ public final class Interpreter { throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) } guard TFL_InterpreterResizeInputTensor( - cInterpreter, - Int32(index), - shape.int32Dimensions, - Int32(shape.rank) - ) == kTfLiteOk + cInterpreter, + Int32(index), + shape.int32Dimensions, + Int32(shape.rank) + ) == kTfLiteOk else { throw InterpreterError.failedToResizeInputTensor(index: index) } @@ -233,7 +229,13 @@ public final class Interpreter { throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount) } + #if swift(>=5.0) + let status = data.withUnsafeBytes { + TFL_TensorCopyFromBuffer(cTensor, $0.baseAddress, data.count) + } + #else let status = data.withUnsafeBytes { TFL_TensorCopyFromBuffer(cTensor, $0, data.count) } + #endif // swift(>=5.0) guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor } return try input(at: index) } diff --git a/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift b/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift index 5de58b997a7..b9dc01c3a9a 100644 --- a/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift +++ b/tensorflow/lite/experimental/swift/Sources/InterpreterError.swift @@ -68,32 +68,4 @@ extension InterpreterError: CustomStringConvertible { } } -#if swift(>=4.2) extension InterpreterError: Equatable {} -#else -extension InterpreterError: Equatable { - public static func == (lhs: InterpreterError, rhs: InterpreterError) -> Bool { - switch (lhs, rhs) { - case (.invalidTensorDataType, .invalidTensorDataType), - (.failedToLoadModel, .failedToLoadModel), - (.failedToCreateInterpreter, .failedToCreateInterpreter), - (.failedToAllocateTensors, .failedToAllocateTensors), - (.allocateTensorsRequired, .allocateTensorsRequired), - (.invokeInterpreterRequired, .invokeInterpreterRequired): - return true - case (.invalidTensorIndex(let lhsIndex, let lhsMaxIndex), - .invalidTensorIndex(let rhsIndex, let rhsMaxIndex)): - return lhsIndex == rhsIndex && lhsMaxIndex == rhsMaxIndex - case (.invalidTensorDataCount(let lhsProvidedCount, let lhsRequiredCount), - .invalidTensorDataCount(let rhsProvidedCount, let rhsRequiredCount)): - return lhsProvidedCount == rhsProvidedCount && lhsRequiredCount == rhsRequiredCount - case (.failedToResizeInputTensor(let lhsIndex), .failedToResizeInputTensor(let rhsIndex)): - return lhsIndex == rhsIndex - case (.tensorFlowLiteError(let lhsMessage), .tensorFlowLiteError(let rhsMessage)): - return lhsMessage == rhsMessage - default: - return false - } - } -} -#endif // swift(>=4.2) diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/AppDelegate.swift b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/AppDelegate.swift index ffa90a06adb..45fd69716df 100644 --- a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/AppDelegate.swift +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/AppDelegate.swift @@ -14,11 +14,3 @@ final class AppDelegate: UIResponder, UIApplicationDelegate { return true } } - -// MARK: - Extensions - -#if !swift(>=4.2) -extension UIApplication { - typealias LaunchOptionsKey = UIApplicationLaunchOptionsKey -} -#endif // !swift(>=4.2) diff --git a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Array+TensorFlowLite.swift b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Array+TensorFlowLite.swift index 56df1ce6597..e9fb026bb7b 100644 --- a/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Array+TensorFlowLite.swift +++ b/tensorflow/lite/experimental/swift/TestApps/TensorFlowLiteApp/TensorFlowLiteApp/Array+TensorFlowLite.swift @@ -11,12 +11,15 @@ extension Array { /// - Parameter unsafeData: The data containing the bytes to turn into an array. init?(unsafeData: Data) { guard unsafeData.count % MemoryLayout.stride == 0 else { return nil } - let elements = unsafeData.withUnsafeBytes { - UnsafeBufferPointer( + #if swift(>=5.0) + self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) } + #else + self = unsafeData.withUnsafeBytes { + .init(UnsafeBufferPointer( start: $0, count: unsafeData.count / MemoryLayout.stride - ) + )) } - self.init(elements) + #endif // swift(>=5.0) } } diff --git a/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift b/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift index e98da5f951e..1a9b898e480 100644 --- a/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift +++ b/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift @@ -287,18 +287,24 @@ private enum AddQuantizedModel { extension Array { /// Creates a new array from the bytes of the given unsafe data. /// + /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit + /// with no indirection or reference-counting operations; otherwise, copying the raw bytes in + /// the `unsafeData`'s buffer to a new array returns an unsafe copy. /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of /// `MemoryLayout.stride`. /// - Parameter unsafeData: The data containing the bytes to turn into an array. init?(unsafeData: Data) { guard unsafeData.count % MemoryLayout.stride == 0 else { return nil } - let elements = unsafeData.withUnsafeBytes { - UnsafeBufferPointer( + #if swift(>=5.0) + self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) } + #else + self = unsafeData.withUnsafeBytes { + .init(UnsafeBufferPointer( start: $0, count: unsafeData.count / MemoryLayout.stride - ) + )) } - self.init(elements) + #endif // swift(>=5.0) } } diff --git a/tensorflow/lite/experimental/swift/Tests/ModelTests.swift b/tensorflow/lite/experimental/swift/Tests/ModelTests.swift index 025db189060..c0fc15e7312 100644 --- a/tensorflow/lite/experimental/swift/Tests/ModelTests.swift +++ b/tensorflow/lite/experimental/swift/Tests/ModelTests.swift @@ -24,9 +24,9 @@ class ModelTests: XCTestCase { let bundle = Bundle(for: type(of: self)) guard let modelPath = bundle.path( - forResource: Constant.modelInfo.name, - ofType: Constant.modelInfo.extension) - else { + forResource: Constant.modelInfo.name, + ofType: Constant.modelInfo.extension + ) else { XCTFail("Failed to get the model file path.") return }