Updates the TensorFlow Lite Swift library to Swift 5.

PiperOrigin-RevId: 245493147
This commit is contained in:
A. Unique TensorFlower 2019-04-26 14:58:53 -07:00 committed by TensorFlower Gardener
parent 9aea6f1452
commit 37c101d175
6 changed files with 35 additions and 60 deletions

View File

@ -87,7 +87,6 @@ public final class Interpreter {
/// - Throws: An error if the model was not ready because tensors were not allocated. /// - Throws: An error if the model was not ready because tensors were not allocated.
public func invoke() throws { public func invoke() throws {
guard TFL_InterpreterInvoke(cInterpreter) == kTfLiteOk else { guard TFL_InterpreterInvoke(cInterpreter) == kTfLiteOk else {
// TODO(b/117510052): Determine which error to throw.
throw InterpreterError.allocateTensorsRequired throw InterpreterError.allocateTensorsRequired
} }
} }
@ -104,8 +103,8 @@ public final class Interpreter {
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
} }
guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)), guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)),
let bytes = TFL_TensorData(cTensor), let bytes = TFL_TensorData(cTensor),
let nameCString = TFL_TensorName(cTensor) let nameCString = TFL_TensorName(cTensor)
else { else {
throw InterpreterError.allocateTensorsRequired throw InterpreterError.allocateTensorsRequired
} }
@ -124,7 +123,6 @@ public final class Interpreter {
let zeroPoint = Int(cQuantizationParams.zero_point) let zeroPoint = Int(cQuantizationParams.zero_point)
var quantizationParameters: QuantizationParameters? = nil var quantizationParameters: QuantizationParameters? = nil
if scale != 0.0 { if scale != 0.0 {
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint) quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
} }
let tensor = Tensor( let tensor = Tensor(
@ -151,10 +149,9 @@ public final class Interpreter {
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
} }
guard let cTensor = TFL_InterpreterGetOutputTensor(cInterpreter, Int32(index)), guard let cTensor = TFL_InterpreterGetOutputTensor(cInterpreter, Int32(index)),
let bytes = TFL_TensorData(cTensor), let bytes = TFL_TensorData(cTensor),
let nameCString = TFL_TensorName(cTensor) let nameCString = TFL_TensorName(cTensor)
else { else {
// TODO(b/117510052): Determine which error to throw.
throw InterpreterError.invokeInterpreterRequired throw InterpreterError.invokeInterpreterRequired
} }
guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else { guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else {
@ -172,7 +169,6 @@ public final class Interpreter {
let zeroPoint = Int(cQuantizationParams.zero_point) let zeroPoint = Int(cQuantizationParams.zero_point)
var quantizationParameters: QuantizationParameters? = nil var quantizationParameters: QuantizationParameters? = nil
if scale != 0.0 { if scale != 0.0 {
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint) quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
} }
let tensor = Tensor( let tensor = Tensor(
@ -200,11 +196,11 @@ public final class Interpreter {
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex) throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
} }
guard TFL_InterpreterResizeInputTensor( guard TFL_InterpreterResizeInputTensor(
cInterpreter, cInterpreter,
Int32(index), Int32(index),
shape.int32Dimensions, shape.int32Dimensions,
Int32(shape.rank) Int32(shape.rank)
) == kTfLiteOk ) == kTfLiteOk
else { else {
throw InterpreterError.failedToResizeInputTensor(index: index) throw InterpreterError.failedToResizeInputTensor(index: index)
} }
@ -233,7 +229,13 @@ public final class Interpreter {
throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount) throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount)
} }
#if swift(>=5.0)
let status = data.withUnsafeBytes {
TFL_TensorCopyFromBuffer(cTensor, $0.baseAddress, data.count)
}
#else
let status = data.withUnsafeBytes { TFL_TensorCopyFromBuffer(cTensor, $0, data.count) } let status = data.withUnsafeBytes { TFL_TensorCopyFromBuffer(cTensor, $0, data.count) }
#endif // swift(>=5.0)
guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor } guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
return try input(at: index) return try input(at: index)
} }

View File

@ -68,32 +68,4 @@ extension InterpreterError: CustomStringConvertible {
} }
} }
#if swift(>=4.2)
extension InterpreterError: Equatable {} extension InterpreterError: Equatable {}
#else
extension InterpreterError: Equatable {
public static func == (lhs: InterpreterError, rhs: InterpreterError) -> Bool {
switch (lhs, rhs) {
case (.invalidTensorDataType, .invalidTensorDataType),
(.failedToLoadModel, .failedToLoadModel),
(.failedToCreateInterpreter, .failedToCreateInterpreter),
(.failedToAllocateTensors, .failedToAllocateTensors),
(.allocateTensorsRequired, .allocateTensorsRequired),
(.invokeInterpreterRequired, .invokeInterpreterRequired):
return true
case (.invalidTensorIndex(let lhsIndex, let lhsMaxIndex),
.invalidTensorIndex(let rhsIndex, let rhsMaxIndex)):
return lhsIndex == rhsIndex && lhsMaxIndex == rhsMaxIndex
case (.invalidTensorDataCount(let lhsProvidedCount, let lhsRequiredCount),
.invalidTensorDataCount(let rhsProvidedCount, let rhsRequiredCount)):
return lhsProvidedCount == rhsProvidedCount && lhsRequiredCount == rhsRequiredCount
case (.failedToResizeInputTensor(let lhsIndex), .failedToResizeInputTensor(let rhsIndex)):
return lhsIndex == rhsIndex
case (.tensorFlowLiteError(let lhsMessage), .tensorFlowLiteError(let rhsMessage)):
return lhsMessage == rhsMessage
default:
return false
}
}
}
#endif // swift(>=4.2)

View File

@ -14,11 +14,3 @@ final class AppDelegate: UIResponder, UIApplicationDelegate {
return true return true
} }
} }
// MARK: - Extensions
#if !swift(>=4.2)
extension UIApplication {
typealias LaunchOptionsKey = UIApplicationLaunchOptionsKey
}
#endif // !swift(>=4.2)

View File

@ -11,12 +11,15 @@ extension Array {
/// - Parameter unsafeData: The data containing the bytes to turn into an array. /// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) { init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil } guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
let elements = unsafeData.withUnsafeBytes { #if swift(>=5.0)
UnsafeBufferPointer<Element>( self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0, start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride count: unsafeData.count / MemoryLayout<Element>.stride
) ))
} }
self.init(elements) #endif // swift(>=5.0)
} }
} }

View File

@ -287,18 +287,24 @@ private enum AddQuantizedModel {
extension Array { extension Array {
/// Creates a new array from the bytes of the given unsafe data. /// Creates a new array from the bytes of the given unsafe data.
/// ///
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
/// `MemoryLayout<Element>.stride`. /// `MemoryLayout<Element>.stride`.
/// - Parameter unsafeData: The data containing the bytes to turn into an array. /// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) { init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil } guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
let elements = unsafeData.withUnsafeBytes { #if swift(>=5.0)
UnsafeBufferPointer<Element>( self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0, start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride count: unsafeData.count / MemoryLayout<Element>.stride
) ))
} }
self.init(elements) #endif // swift(>=5.0)
} }
} }

View File

@ -24,9 +24,9 @@ class ModelTests: XCTestCase {
let bundle = Bundle(for: type(of: self)) let bundle = Bundle(for: type(of: self))
guard let modelPath = bundle.path( guard let modelPath = bundle.path(
forResource: Constant.modelInfo.name, forResource: Constant.modelInfo.name,
ofType: Constant.modelInfo.extension) ofType: Constant.modelInfo.extension
else { ) else {
XCTFail("Failed to get the model file path.") XCTFail("Failed to get the model file path.")
return return
} }