Updates the TensorFlow Lite Swift library to Swift 5.
PiperOrigin-RevId: 245493147
This commit is contained in:
parent
9aea6f1452
commit
37c101d175
@ -87,7 +87,6 @@ public final class Interpreter {
|
|||||||
/// - Throws: An error if the model was not ready because tensors were not allocated.
|
/// - Throws: An error if the model was not ready because tensors were not allocated.
|
||||||
public func invoke() throws {
|
public func invoke() throws {
|
||||||
guard TFL_InterpreterInvoke(cInterpreter) == kTfLiteOk else {
|
guard TFL_InterpreterInvoke(cInterpreter) == kTfLiteOk else {
|
||||||
// TODO(b/117510052): Determine which error to throw.
|
|
||||||
throw InterpreterError.allocateTensorsRequired
|
throw InterpreterError.allocateTensorsRequired
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -124,7 +123,6 @@ public final class Interpreter {
|
|||||||
let zeroPoint = Int(cQuantizationParams.zero_point)
|
let zeroPoint = Int(cQuantizationParams.zero_point)
|
||||||
var quantizationParameters: QuantizationParameters? = nil
|
var quantizationParameters: QuantizationParameters? = nil
|
||||||
if scale != 0.0 {
|
if scale != 0.0 {
|
||||||
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
|
|
||||||
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
||||||
}
|
}
|
||||||
let tensor = Tensor(
|
let tensor = Tensor(
|
||||||
@ -154,7 +152,6 @@ public final class Interpreter {
|
|||||||
let bytes = TFL_TensorData(cTensor),
|
let bytes = TFL_TensorData(cTensor),
|
||||||
let nameCString = TFL_TensorName(cTensor)
|
let nameCString = TFL_TensorName(cTensor)
|
||||||
else {
|
else {
|
||||||
// TODO(b/117510052): Determine which error to throw.
|
|
||||||
throw InterpreterError.invokeInterpreterRequired
|
throw InterpreterError.invokeInterpreterRequired
|
||||||
}
|
}
|
||||||
guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else {
|
guard let dataType = TensorDataType(type: TFL_TensorType(cTensor)) else {
|
||||||
@ -172,7 +169,6 @@ public final class Interpreter {
|
|||||||
let zeroPoint = Int(cQuantizationParams.zero_point)
|
let zeroPoint = Int(cQuantizationParams.zero_point)
|
||||||
var quantizationParameters: QuantizationParameters? = nil
|
var quantizationParameters: QuantizationParameters? = nil
|
||||||
if scale != 0.0 {
|
if scale != 0.0 {
|
||||||
// TODO(b/117510052): Update this check once the TfLiteQuantizationParams struct has a mode.
|
|
||||||
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
quantizationParameters = QuantizationParameters(scale: scale, zeroPoint: zeroPoint)
|
||||||
}
|
}
|
||||||
let tensor = Tensor(
|
let tensor = Tensor(
|
||||||
@ -233,7 +229,13 @@ public final class Interpreter {
|
|||||||
throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount)
|
throw InterpreterError.invalidTensorDataCount(provided: data.count, required: byteCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if swift(>=5.0)
|
||||||
|
let status = data.withUnsafeBytes {
|
||||||
|
TFL_TensorCopyFromBuffer(cTensor, $0.baseAddress, data.count)
|
||||||
|
}
|
||||||
|
#else
|
||||||
let status = data.withUnsafeBytes { TFL_TensorCopyFromBuffer(cTensor, $0, data.count) }
|
let status = data.withUnsafeBytes { TFL_TensorCopyFromBuffer(cTensor, $0, data.count) }
|
||||||
|
#endif // swift(>=5.0)
|
||||||
guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
|
guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
|
||||||
return try input(at: index)
|
return try input(at: index)
|
||||||
}
|
}
|
||||||
|
@ -68,32 +68,4 @@ extension InterpreterError: CustomStringConvertible {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#if swift(>=4.2)
|
|
||||||
extension InterpreterError: Equatable {}
|
extension InterpreterError: Equatable {}
|
||||||
#else
|
|
||||||
extension InterpreterError: Equatable {
|
|
||||||
public static func == (lhs: InterpreterError, rhs: InterpreterError) -> Bool {
|
|
||||||
switch (lhs, rhs) {
|
|
||||||
case (.invalidTensorDataType, .invalidTensorDataType),
|
|
||||||
(.failedToLoadModel, .failedToLoadModel),
|
|
||||||
(.failedToCreateInterpreter, .failedToCreateInterpreter),
|
|
||||||
(.failedToAllocateTensors, .failedToAllocateTensors),
|
|
||||||
(.allocateTensorsRequired, .allocateTensorsRequired),
|
|
||||||
(.invokeInterpreterRequired, .invokeInterpreterRequired):
|
|
||||||
return true
|
|
||||||
case (.invalidTensorIndex(let lhsIndex, let lhsMaxIndex),
|
|
||||||
.invalidTensorIndex(let rhsIndex, let rhsMaxIndex)):
|
|
||||||
return lhsIndex == rhsIndex && lhsMaxIndex == rhsMaxIndex
|
|
||||||
case (.invalidTensorDataCount(let lhsProvidedCount, let lhsRequiredCount),
|
|
||||||
.invalidTensorDataCount(let rhsProvidedCount, let rhsRequiredCount)):
|
|
||||||
return lhsProvidedCount == rhsProvidedCount && lhsRequiredCount == rhsRequiredCount
|
|
||||||
case (.failedToResizeInputTensor(let lhsIndex), .failedToResizeInputTensor(let rhsIndex)):
|
|
||||||
return lhsIndex == rhsIndex
|
|
||||||
case (.tensorFlowLiteError(let lhsMessage), .tensorFlowLiteError(let rhsMessage)):
|
|
||||||
return lhsMessage == rhsMessage
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif // swift(>=4.2)
|
|
||||||
|
@ -14,11 +14,3 @@ final class AppDelegate: UIResponder, UIApplicationDelegate {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: - Extensions
|
|
||||||
|
|
||||||
#if !swift(>=4.2)
|
|
||||||
extension UIApplication {
|
|
||||||
typealias LaunchOptionsKey = UIApplicationLaunchOptionsKey
|
|
||||||
}
|
|
||||||
#endif // !swift(>=4.2)
|
|
||||||
|
@ -11,12 +11,15 @@ extension Array {
|
|||||||
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
||||||
init?(unsafeData: Data) {
|
init?(unsafeData: Data) {
|
||||||
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||||
let elements = unsafeData.withUnsafeBytes {
|
#if swift(>=5.0)
|
||||||
UnsafeBufferPointer<Element>(
|
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
|
||||||
|
#else
|
||||||
|
self = unsafeData.withUnsafeBytes {
|
||||||
|
.init(UnsafeBufferPointer<Element>(
|
||||||
start: $0,
|
start: $0,
|
||||||
count: unsafeData.count / MemoryLayout<Element>.stride
|
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||||
)
|
))
|
||||||
}
|
}
|
||||||
self.init(elements)
|
#endif // swift(>=5.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -287,18 +287,24 @@ private enum AddQuantizedModel {
|
|||||||
extension Array {
|
extension Array {
|
||||||
/// Creates a new array from the bytes of the given unsafe data.
|
/// Creates a new array from the bytes of the given unsafe data.
|
||||||
///
|
///
|
||||||
|
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
|
||||||
|
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
|
||||||
|
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
|
||||||
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
|
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
|
||||||
/// `MemoryLayout<Element>.stride`.
|
/// `MemoryLayout<Element>.stride`.
|
||||||
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
|
||||||
init?(unsafeData: Data) {
|
init?(unsafeData: Data) {
|
||||||
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||||
let elements = unsafeData.withUnsafeBytes {
|
#if swift(>=5.0)
|
||||||
UnsafeBufferPointer<Element>(
|
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
|
||||||
|
#else
|
||||||
|
self = unsafeData.withUnsafeBytes {
|
||||||
|
.init(UnsafeBufferPointer<Element>(
|
||||||
start: $0,
|
start: $0,
|
||||||
count: unsafeData.count / MemoryLayout<Element>.stride
|
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||||
)
|
))
|
||||||
}
|
}
|
||||||
self.init(elements)
|
#endif // swift(>=5.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,8 +25,8 @@ class ModelTests: XCTestCase {
|
|||||||
let bundle = Bundle(for: type(of: self))
|
let bundle = Bundle(for: type(of: self))
|
||||||
guard let modelPath = bundle.path(
|
guard let modelPath = bundle.path(
|
||||||
forResource: Constant.modelInfo.name,
|
forResource: Constant.modelInfo.name,
|
||||||
ofType: Constant.modelInfo.extension)
|
ofType: Constant.modelInfo.extension
|
||||||
else {
|
) else {
|
||||||
XCTFail("Failed to get the model file path.")
|
XCTFail("Failed to get the model file path.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user