Updates the TensorFlow Lite Swift library to Swift 5.
PiperOrigin-RevId: 245493147
This commit is contained in:
parent
9aea6f1452
commit
37c101d175
@ -87,7 +87,6 @@ public final class Interpreter {
|
||||
/// - Throws: An error if the model was not ready because tensors were not allocated.
|
||||
public func invoke() throws {
|
||||
guard TFL_InterpreterInvoke(cInterpreter) == kTfLiteOk else {
|
||||
// TODO(b/117510052): Determine which error to throw.
|
||||
throw InterpreterError.allocateTensorsRequired
|
||||
}
|
||||
}
|
||||
@ -104,8 +103,8 @@ public final class Interpreter {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard let cTensor = TFL_InterpreterGetInputTensor(cInterpreter, Int32(index)),
|
||||
let bytes = TFL_TensorData(cTensor),
|
||||
let nameCString = TFL_TensorName(cTensor)
|
||||
let bytes = TFL_TensorData(cTensor),
|
||||
let nameCString = TFL_TensorName(cTensor)
|
||||
else {
|
||||
throw InterpreterError.allocateTensorsRequired
|
||||
}
|
||||
@ -124,7 +123,6 @@ public final class Interpreter {
|
||||
let zeroPoint = Int(cQuantizationParams.zero_point)
|
||||
var quantizationParameters: QuantizationParameters? = nil
|
||||
if scale != 0.0 {
|
||||
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
|
||||
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
||||
}
|
||||
let tensor = Tensor(
|
||||
@ -151,10 +149,9 @@ public final class Interpreter {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard let cTensor = TFL_InterpreterGetOutputTensor(cInterpreter, Int32(index)),
|
||||
let bytes = TFL_TensorData(cTensor),
|
||||
let nameCString = TFL_TensorName(cTensor)
|
||||
let bytes = TFL_TensorData(cTensor),
|
||||
let nameCString = TFL_TensorName(cTensor)
|
||||
else {
|
||||
// TODO(b/117510052): Determine which error to throw.
|
||||
throw InterpreterError.invokeInterpreterRequired
|
||||
}
|
||||
guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else {
|
||||
@ -172,7 +169,6 @@ public final class Interpreter {
|
||||
let zeroPoint = Int(cQuantizationParams.zero_point)
|
||||
var quantizationParameters: QuantizationParameters? = nil
|
||||
if scale != 0.0 {
|
||||
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
|
||||
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
||||
}
|
||||
let tensor = Tensor(
|
||||
@ -200,11 +196,11 @@ public final class Interpreter {
|
||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||
}
|
||||
guard TFL_InterpreterResizeInputTensor(
|
||||
cInterpreter,
|
||||
Int32(index),
|
||||
shape.int32Dimensions,
|
||||
Int32(shape.rank)
|
||||
) == kTfLiteOk
|
||||
cInterpreter,
|
||||
Int32(index),
|
||||
shape.int32Dimensions,
|
||||
Int32(shape.rank)
|
||||
) == kTfLiteOk
|
||||
else {
|
||||
throw InterpreterError.failedToResizeInputTensor(index: index)
|
||||
}
|
||||
@ -233,7 +229,13 @@ public final class Interpreter {
|
||||
throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount)
|
||||
}
|
||||
|
||||
#if swift(>=5.0)
|
||||
let status = data.withUnsafeBytes {
|
||||
TFL_TensorCopyFromBuffer(cTensor, $0.baseAddress, data.count)
|
||||
}
|
||||
#else
|
||||
let status = data.withUnsafeBytes { TFL_TensorCopyFromBuffer(cTensor, $0, data.count) }
|
||||
#endif // swift(>=5.0)
|
||||
guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
|
||||
return try input(at: index)
|
||||
}
|
||||
|
@ -68,32 +68,4 @@ extension InterpreterError: CustomStringConvertible {
|
||||
}
|
||||
}
|
||||
|
||||
#if swift(>=4.2)
|
||||
extension InterpreterError: Equatable {}
|
||||
#else
|
||||
extension InterpreterError: Equatable {
|
||||
public static func == (lhs: InterpreterError, rhs: InterpreterError) -> Bool {
|
||||
switch (lhs, rhs) {
|
||||
case (.invalidTensorDataType, .invalidTensorDataType),
|
||||
(.failedToLoadModel, .failedToLoadModel),
|
||||
(.failedToCreateInterpreter, .failedToCreateInterpreter),
|
||||
(.failedToAllocateTensors, .failedToAllocateTensors),
|
||||
(.allocateTensorsRequired, .allocateTensorsRequired),
|
||||
(.invokeInterpreterRequired, .invokeInterpreterRequired):
|
||||
return true
|
||||
case (.invalidTensorIndex(let lhsIndex, let lhsMaxIndex),
|
||||
.invalidTensorIndex(let rhsIndex, let rhsMaxIndex)):
|
||||
return lhsIndex == rhsIndex && lhsMaxIndex == rhsMaxIndex
|
||||
case (.invalidTensorDataCount(let lhsProvidedCount, let lhsRequiredCount),
|
||||
.invalidTensorDataCount(let rhsProvidedCount, let rhsRequiredCount)):
|
||||
return lhsProvidedCount == rhsProvidedCount && lhsRequiredCount == rhsRequiredCount
|
||||
case (.failedToResizeInputTensor(let lhsIndex), .failedToResizeInputTensor(let rhsIndex)):
|
||||
return lhsIndex == rhsIndex
|
||||
case (.tensorFlowLiteError(let lhsMessage), .tensorFlowLiteError(let rhsMessage)):
|
||||
return lhsMessage == rhsMessage
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // swift(>=4.2)
|
||||
|
@ -14,11 +14,3 @@ final class AppDelegate: UIResponder, UIApplicationDelegate {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Extensions
|
||||
|
||||
#if !swift(>=4.2)
|
||||
extension UIApplication {
|
||||
typealias LaunchOptionsKey = UIApplicationLaunchOptionsKey
|
||||
}
|
||||
#endif // !swift(>=4.2)
|
||||
|
@ -11,12 +11,15 @@ extension Array {
|
||||
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
||||
init?(unsafeData: Data) {
|
||||
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||
let elements = unsafeData.withUnsafeBytes {
|
||||
UnsafeBufferPointer<Element>(
|
||||
#if swift(>=5.0)
|
||||
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
|
||||
#else
|
||||
self = unsafeData.withUnsafeBytes {
|
||||
.init(UnsafeBufferPointer<Element>(
|
||||
start: $0,
|
||||
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||
)
|
||||
))
|
||||
}
|
||||
self.init(elements)
|
||||
#endif // swift(>=5.0)
|
||||
}
|
||||
}
|
||||
|
@ -287,18 +287,24 @@ private enum AddQuantizedModel {
|
||||
extension Array {
|
||||
/// Creates a new array from the bytes of the given unsafe data.
|
||||
///
|
||||
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
|
||||
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
|
||||
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
|
||||
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
|
||||
/// `MemoryLayout<Element>.stride`.
|
||||
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
||||
init?(unsafeData: Data) {
|
||||
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||
let elements = unsafeData.withUnsafeBytes {
|
||||
UnsafeBufferPointer<Element>(
|
||||
#if swift(>=5.0)
|
||||
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
|
||||
#else
|
||||
self = unsafeData.withUnsafeBytes {
|
||||
.init(UnsafeBufferPointer<Element>(
|
||||
start: $0,
|
||||
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||
)
|
||||
))
|
||||
}
|
||||
self.init(elements)
|
||||
#endif // swift(>=5.0)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -24,9 +24,9 @@ class ModelTests: XCTestCase {
|
||||
|
||||
let bundle = Bundle(for: type(of: self))
|
||||
guard let modelPath = bundle.path(
|
||||
forResource: Constant.modelInfo.name,
|
||||
ofType: Constant.modelInfo.extension)
|
||||
else {
|
||||
forResource: Constant.modelInfo.name,
|
||||
ofType: Constant.modelInfo.extension
|
||||
) else {
|
||||
XCTFail("Failed to get the model file path.")
|
||||
return
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user