Updates InterpreterOptions, TensorShape, and TensorDataType to be nested types.
PiperOrigin-RevId: 265160862
parent 2fce6eaf6d
commit c3a9eb7c28
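For orientation, the rename moves three top-level types under their owning types: `InterpreterOptions` becomes `Interpreter.Options`, `TensorShape` becomes `Tensor.Shape`, and `TensorDataType` becomes `Tensor.DataType`. A minimal usage sketch under the new names (the model path is a placeholder, not part of this change):

import TensorFlowLite

do {
  // Formerly the top-level `InterpreterOptions`.
  var options = Interpreter.Options()
  options.threadCount = 2
  let interpreter = try Interpreter(modelPath: "/path/to/model.tflite", options: options)

  // Formerly the top-level `TensorShape` and `TensorDataType`.
  let shape: Tensor.Shape = [2]
  let dataType: Tensor.DataType = .float32
  _ = (interpreter, shape, dataType)
} catch {
  print("Failed to create the interpreter: \(error)")
}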
@@ -16,9 +16,9 @@ import TensorFlowLiteC
/// A delegate that the `Interpreter` uses to perform TensorFlow Lite model computations.
public protocol Delegate: class {
/// `TFL_Delegate` C pointer type.
typealias CDelegate = OpaquePointer
/// The `TfLiteDelegate` C pointer type.
typealias CDelegate = UnsafeMutablePointer<TfLiteDelegate>
/// Delegate that performs model computations.
var cDelegate: CDelegate? { get }
/// The delegate that performs model computations.
var cDelegate: CDelegate { get }
}
@@ -17,32 +17,36 @@ import TensorFlowLiteC
/// A TensorFlow Lite interpreter that performs inference from a given model.
public final class Interpreter {
/// `TfLiteInterpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
private typealias CInterpreter = OpaquePointer
/// The configuration options for the `Interpreter`.
public let options: Options?
/// Total number of input tensors associated with the model.
/// The total number of input tensors associated with the model.
public var inputTensorCount: Int {
return Int(TfLiteInterpreterGetInputTensorCount(cInterpreter))
}
/// Total number of output tensors associated with the model.
/// The total number of output tensors associated with the model.
public var outputTensorCount: Int {
return Int(TfLiteInterpreterGetOutputTensorCount(cInterpreter))
}
/// Underlying `TfLiteInterpreter` C pointer.
/// The `TfLiteInterpreter` C pointer type represented as an `UnsafePointer<TfLiteInterpreter>`.
private typealias CInterpreter = OpaquePointer
/// The underlying `TfLiteInterpreter` C pointer.
private var cInterpreter: CInterpreter?
/// Creates a new model interpreter instance.
/// Creates a new instance with the given values.
///
/// - Parameters:
/// - modelPath: Local file path to a TensorFlow Lite model.
/// - options: Custom configurations for the interpreter. Default is `nil` indicating that the
/// interpreter will determine the configuration options.
/// - options: Custom configuration options for the interpreter. Default is `nil` indicating
/// that the interpreter will determine the configuration options.
/// - Throws: An error if the model could not be loaded or the interpreter could not be created.
public init(modelPath: String, options: InterpreterOptions? = nil) throws {
public init(modelPath: String, options: Options? = nil) throws {
guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
self.options = options
let cInterpreterOptions: OpaquePointer? = try options.map { options in
guard let cOptions = TfLiteInterpreterOptionsCreate() else {
throw InterpreterError.failedToCreateInterpreter

@@ -105,14 +109,14 @@ public final class Interpreter {
else {
throw InterpreterError.allocateTensorsRequired
}
guard let dataType = TensorDataType(type: TfLiteTensorType(cTensor)) else {
guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else {
throw InterpreterError.invalidTensorDataType
}
let name = String(cString: nameCString)
let rank = TfLiteTensorNumDims(cTensor)
let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
let shape = TensorShape(dimensions)
let shape = Tensor.Shape(dimensions)
let byteCount = TfLiteTensorByteSize(cTensor)
let data = Data(bytes: bytes, count: byteCount)
let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)

@@ -151,14 +155,14 @@ public final class Interpreter {
else {
throw InterpreterError.invokeInterpreterRequired
}
guard let dataType = TensorDataType(type: TfLiteTensorType(cTensor)) else {
guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else {
throw InterpreterError.invalidTensorDataType
}
let name = String(cString: nameCString)
let rank = TfLiteTensorNumDims(cTensor)
let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
let shape = TensorShape(dimensions)
let shape = Tensor.Shape(dimensions)
let byteCount = TfLiteTensorByteSize(cTensor)
let data = Data(bytes: bytes, count: byteCount)
let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)

@@ -187,7 +191,7 @@ public final class Interpreter {
/// - index: The index for the input tensor.
/// - shape: The shape that the input tensor should be resized to.
/// - Throws: An error if the input tensor at the given index could not be resized.
public func resizeInput(at index: Int, to shape: TensorShape) throws {
public func resizeInput(at index: Int, to shape: Tensor.Shape) throws {
let maxIndex = inputTensorCount - 1
guard case 0...maxIndex = index else {
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)

@@ -237,7 +241,7 @@ public final class Interpreter {
return try input(at: index)
}
/// Allocates memory for all input tensors based on their `TensorShape`s.
/// Allocates memory for all input tensors based on their `Tensor.Shape`s.
///
/// - Note: This is a relatively expensive operation and should only be called after creating the
/// interpreter and/or resizing any input tensors.

@@ -249,7 +253,17 @@ public final class Interpreter {
}
}
// MARK: - Extensions
extension Interpreter {
/// Options for configuring the `Interpreter`.
public struct Options: Equatable, Hashable {
/// The maximum number of CPU threads that the interpreter should run on. Default is `nil`
/// indicating that the `Interpreter` will decide the number of threads to use.
public var threadCount: Int? = nil
/// Creates a new instance with the default values.
public init() {}
}
}
extension String {
/// Returns a new `String` initialized by using the given format C array as a template into which
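Putting the renamed pieces of this file together, a rough end-to-end sketch of the updated `Interpreter` API (the model path, tensor index, and shape below are illustrative, not part of the diff):

import Foundation
import TensorFlowLite

func runModel(at modelPath: String, input: Data) throws -> Data {
  var options = Interpreter.Options()
  options.threadCount = 2
  let interpreter = try Interpreter(modelPath: modelPath, options: options)

  // Resize the first input, allocate tensors, copy the input bytes in,
  // run inference, and read back the first output tensor's data.
  try interpreter.resizeInput(at: 0, to: Tensor.Shape(2))
  try interpreter.allocateTensors()
  try interpreter.copy(input, toInputAt: 0)
  try interpreter.invoke()
  return try interpreter.output(at: 0).data
}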
@@ -14,7 +14,7 @@
import Foundation
/// TensorFlow Lite interpreter errors.
/// Errors thrown by the TensorFlow Lite `Interpreter`.
public enum InterpreterError: Error, Equatable, Hashable {
case invalidTensorIndex(index: Int, maxIndex: Int)
case invalidTensorDataCount(provided: Int, required: Int)

@@ -29,10 +29,8 @@ public enum InterpreterError: Error, Equatable, Hashable {
case tensorFlowLiteError(String)
}
// MARK: - Extensions
extension InterpreterError: LocalizedError {
/// Localized description of the interpreter error.
/// A localized description of the interpreter error.
public var errorDescription: String? {
switch self {
case .invalidTensorIndex(let index, let maxIndex):

@@ -62,6 +60,6 @@ extension InterpreterError: LocalizedError {
}
extension InterpreterError: CustomStringConvertible {
/// Textual representation of the TensorFlow Lite interpreter error.
/// A textual representation of the TensorFlow Lite interpreter error.
public var description: String { return errorDescription ?? "Unknown error." }
}
@@ -1,23 +0,0 @@
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Custom configuration options for a TensorFlow Lite `Interpreter`.
public struct InterpreterOptions: Equatable {
/// Maximum number of CPU threads that the interpreter should run on. Default is `nil` indicating
/// that the `Interpreter` will decide the number of threads to use.
public var threadCount: Int? = nil
/// Creates a new instance of interpreter options.
public init() {}
}
@@ -14,15 +14,15 @@
import TensorFlowLiteC
/// A TensorFlow Lite model used by the 'Interpreter` to perform inference.
/// A TensorFlow Lite model used by the `Interpreter` to perform inference.
final class Model {
/// `TfLiteModel` C pointer type represented as an `UnsafePointer<TfLiteModel>`.
/// The `TfLiteModel` C pointer type represented as an `UnsafePointer<TfLiteModel>`.
typealias CModel = OpaquePointer
/// Underlying `TfLiteModel` C pointer.
/// The underlying `TfLiteModel` C pointer.
let cModel: CModel?
/// Creates a new model instance.
/// Creates a new instance with the given `filePath`.
///
/// - Precondition: Initialization can fail if the given `filePath` is invalid.
/// - Parameters:
@@ -16,11 +16,11 @@
/// be mapped to float values using the following conversion:
/// `realValue = scale * (quantizedValue - zeroPoint)`.
public struct QuantizationParameters: Equatable, Hashable {
/// Difference between real values corresponding to consecutive quantized values differing by 1.
/// For example, the range of quantized values for `UInt8` data type is [0, 255].
/// The difference between real values corresponding to consecutive quantized values differing by
/// 1. For example, the range of quantized values for `UInt8` data type is [0, 255].
public let scale: Float
/// Quantized value that corresponds to the real 0 value.
/// The quantized value that corresponds to the real 0 value.
public let zeroPoint: Int
/// Creates a new quantization parameters instance.
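The conversion documented above, `realValue = scale * (quantizedValue - zeroPoint)`, written out as a small helper; the function is illustrative and not part of this change:

import TensorFlowLite

/// Maps a quantized `UInt8` value back to a real value using the tensor's parameters.
func dequantize(_ quantizedValue: UInt8, using parameters: QuantizationParameters) -> Float {
  return parameters.scale * Float(Int(quantizedValue) - parameters.zeroPoint)
}

// With scale 0.003922 and zeroPoint 0 (the values used by the quantized add model tests
// below), a quantized value of 255 maps to roughly 1.0.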
@@ -17,19 +17,19 @@ import TensorFlowLiteC
/// An input or output tensor in a TensorFlow Lite graph.
public struct Tensor: Equatable, Hashable {
/// Name of the tensor.
/// The name of the tensor.
public let name: String
/// Data type of the tensor.
public let dataType: TensorDataType
/// The data type of the tensor.
public let dataType: DataType
/// Shape of the tensor.
public let shape: TensorShape
/// The shape of the tensor.
public let shape: Shape
/// Data in the input or output tensor.
/// The data in the input or output tensor.
public let data: Data
/// Quantization parameters for the tensor if using a quantized model.
/// The quantization parameters for the tensor if using a quantized model.
public let quantizationParameters: QuantizationParameters?
/// Creates a new input or output tensor instance.

@@ -43,8 +43,8 @@ public struct Tensor: Equatable, Hashable {
/// Default is `nil`.
init(
name: String,
dataType: TensorDataType,
shape: TensorShape,
dataType: DataType,
shape: Shape,
data: Data,
quantizationParameters: QuantizationParameters? = nil
) {

@@ -56,83 +56,86 @@ public struct Tensor: Equatable, Hashable {
}
}
/// Supported TensorFlow Lite tensor data types.
public enum TensorDataType: Equatable, Hashable {
/// Boolean.
case bool
/// 8-bit unsigned integer.
case uInt8
/// 16-bit signed integer.
case int16
/// 32-bit signed integer.
case int32
/// 64-bit signed integer.
case int64
/// 16-bit half precision floating point.
case float16
/// 32-bit single precision floating point.
case float32
extension Tensor {
/// The supported `Tensor` data types.
public enum DataType: Equatable, Hashable {
/// A boolean.
case bool
/// An 8-bit unsigned integer.
case uInt8
/// A 16-bit signed integer.
case int16
/// A 32-bit signed integer.
case int32
/// A 64-bit signed integer.
case int64
/// A 16-bit half precision floating point.
case float16
/// A 32-bit single precision floating point.
case float32
/// Creates a new tensor data type from the given `TfLiteType` or `nil` if the data type is
/// unsupported or could not be determined because there was an error.
///
/// - Parameter type: A data type supported by a tensor.
init?(type: TfLiteType) {
switch type {
case kTfLiteBool:
self = .bool
case kTfLiteUInt8:
self = .uInt8
case kTfLiteInt16:
self = .int16
case kTfLiteInt32:
self = .int32
case kTfLiteInt64:
self = .int64
case kTfLiteFloat16:
self = .float16
case kTfLiteFloat32:
self = .float32
case kTfLiteNoType:
fallthrough
default:
return nil
/// Creates a new instance from the given `TfLiteType` or `nil` if the data type is unsupported
/// or could not be determined because there was an error.
///
/// - Parameter type: Data type supported by a tensor.
init?(type: TfLiteType) {
switch type {
case kTfLiteBool:
self = .bool
case kTfLiteUInt8:
self = .uInt8
case kTfLiteInt16:
self = .int16
case kTfLiteInt32:
self = .int32
case kTfLiteInt64:
self = .int64
case kTfLiteFloat16:
self = .float16
case kTfLiteFloat32:
self = .float32
case kTfLiteNoType:
fallthrough
default:
return nil
}
}
}
}
/// The shape of a TensorFlow Lite tensor.
public struct TensorShape: Equatable, Hashable {
extension Tensor {
/// The shape of a `Tensor`.
public struct Shape: Equatable, Hashable {
/// The number of dimensions of the tensor.
public let rank: Int
/// The number of dimensions of the tensor.
public let rank: Int
/// An array of dimensions for the tensor.
public let dimensions: [Int]
/// Array of dimensions for the tensor.
public let dimensions: [Int]
/// An array of `Int32` dimensions for the tensor.
var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
/// Array of `Int32` dimensions for the tensor.
var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
/// Creates a new instance with the given array of dimensions.
///
/// - Parameters:
/// - dimensions: Dimensions for the tensor.
public init(_ dimensions: [Int]) {
self.rank = dimensions.count
self.dimensions = dimensions
}
/// Creates a new tensor shape instance with the given array of dimensions.
///
/// - Parameters:
/// - dimensions: Dimensions for the tensor.
public init(_ dimensions: [Int]) {
self.rank = dimensions.count
self.dimensions = dimensions
}
/// Creates a new tensor shape instance with the given elements representing the dimensions.
///
/// - Parameters:
/// - elements: Dimensions for the tensor.
public init(_ elements: Int...) {
self.init(elements)
/// Creates a new instance with the given elements representing the dimensions.
///
/// - Parameters:
/// - elements: Dimensions for the tensor.
public init(_ elements: Int...) {
self.init(elements)
}
}
}
extension TensorShape: ExpressibleByArrayLiteral {
/// Creates a new tensor shape instance with the given array literal representing the dimensions.
extension Tensor.Shape: ExpressibleByArrayLiteral {
/// Creates a new instance with the given array literal representing the dimensions.
///
/// - Parameters:
/// - arrayLiteral: Dimensions for the tensor.
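As a quick illustration of the nested shape type introduced above, the three equivalent ways to construct a `Tensor.Shape` (the dimension values are arbitrary):

import TensorFlowLite

let fromArray = Tensor.Shape([2, 2, 3])    // init(_ dimensions: [Int])
let fromElements = Tensor.Shape(2, 2, 3)   // init(_ elements: Int...)
let fromLiteral: Tensor.Shape = [2, 2, 3]  // ExpressibleByArrayLiteral

// All three describe the same rank-3 shape.
assert(fromArray == fromElements && fromElements == fromLiteral)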
@@ -12,18 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import class TensorFlowLite.Interpreter
import struct TensorFlowLite.InterpreterOptions
import struct TensorFlowLite.Tensor
import struct TensorFlowLite.TensorShape
import enum TensorFlowLite.Runtime
import TensorFlowLite
import UIKit
class ViewController: UIViewController {
// MARK: - Properties
/// TensorFlowLite interpreter object for performing inference from a given model.
/// TensorFlow Lite interpreter object for performing inference from a given model.
private var interpreter: Interpreter?
/// Serial dispatch queue for managing `Interpreter` calls.

@@ -122,7 +118,7 @@ class ViewController: UIViewController {
private func setUpInterpreter(withModelPath modelPath: String) {
interpreterQueue.async {
do {
var options = InterpreterOptions()
var options = Interpreter.Options()
options.threadCount = 2
self.interpreter = try Interpreter(modelPath: modelPath, options: options)
} catch let error {

@@ -211,7 +207,7 @@ class ViewController: UIViewController {
return
}
do {
let shape = TensorShape(2)
let shape = Tensor.Shape(2)
try (0..<interpreter.inputTensorCount).forEach { index in
try interpreter.resizeInput(at: index, to: shape)
}
@@ -1,43 +0,0 @@
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
class InterpreterOptionsTests: XCTestCase {
func testInterpreterOptions_InitWithDefaultValues() {
let options = InterpreterOptions()
XCTAssertNil(options.threadCount)
}
func testInterpreterOptions_InitWithCustomValues() {
var options = InterpreterOptions()
options.threadCount = 2
XCTAssertEqual(options.threadCount, 2)
}
func testInterpreterOptions_Equatable() {
var options1 = InterpreterOptions()
var options2 = InterpreterOptions()
XCTAssertEqual(options1, options2)
options1.threadCount = 2
options2.threadCount = 2
XCTAssertEqual(options1, options2)
options2.threadCount = 3
XCTAssertNotEqual(options1, options2)
}
}
@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
@testable import TensorFlowLite
class InterpreterTests: XCTestCase {
var interpreter: Interpreter!

@@ -31,55 +32,56 @@ class InterpreterTests: XCTestCase {
super.tearDown()
}
func testInterpreter_InitWithModelPath() {
func testInitWithModelPath() {
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
}
func testInterpreter_Init_ThrowsFailedToLoadModel() {
func testInit_ThrowsFailedToLoadModel() {
XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
}
}
func testInterpreter_InitWithModelPathAndOptions() {
var options = InterpreterOptions()
func testInitWithModelPathAndOptions() throws {
var options = Interpreter.Options()
options.threadCount = 2
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path, options: options))
let interpreter = try Interpreter(modelPath: AddModel.path, options: options)
XCTAssertNotNil(interpreter.options)
}
func testInterpreter_InputTensorCount() {
func testInputTensorCount() {
XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
}
func testInterpreter_OutputTensorCount() {
func testOutputTensorCount() {
XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
}
func testInterpreter_Invoke() throws {
func testInvoke() throws {
try interpreter.allocateTensors()
XCTAssertNoThrow(try interpreter.invoke())
}
func testInterpreter_Invoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
func testInvoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
XCTAssertThrowsError(try interpreter.invoke()) { error in
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
}
}
func testInterpreter_InputTensorAtIndex() throws {
func testInputTensorAtIndex() throws {
try setUpAddModelInputTensor()
let inputTensor = try interpreter.input(at: AddModel.validIndex)
XCTAssertEqual(inputTensor, AddModel.inputTensor)
}
func testInterpreter_InputTensorAtIndex_QuantizedModel() throws {
func testInputTensorAtIndex_QuantizedModel() throws {
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
try setUpAddQuantizedModelInputTensor()
let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
}
func testInterpreter_InputTensorAtIndex_ThrowsInvalidIndex() throws {
func testInputTensorAtIndex_ThrowsInvalidIndex() throws {
try interpreter.allocateTensors()
XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
let maxIndex = AddModel.inputTensorCount - 1

@@ -90,13 +92,13 @@ class InterpreterTests: XCTestCase {
}
}
func testInterpreter_InputTensorAtIndex_ThrowsAllocateTensorsRequired() {
func testInputTensorAtIndex_ThrowsAllocateTensorsRequired() {
XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
}
}
func testInterpreter_OutputTensorAtIndex() throws {
func testOutputTensorAtIndex() throws {
try setUpAddModelInputTensor()
try interpreter.invoke()
let outputTensor = try interpreter.output(at: AddModel.validIndex)

@@ -105,7 +107,7 @@ class InterpreterTests: XCTestCase {
XCTAssertEqual(expectedResults, AddModel.results)
}
func testInterpreter_OutputTensorAtIndex_QuantizedModel() throws {
func testOutputTensorAtIndex_QuantizedModel() throws {
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
try setUpAddQuantizedModelInputTensor()
try interpreter.invoke()

@@ -115,7 +117,7 @@ class InterpreterTests: XCTestCase {
XCTAssertEqual(expectedResults, AddQuantizedModel.results)
}
func testInterpreter_OutputTensorAtIndex_ThrowsInvalidIndex() throws {
func testOutputTensorAtIndex_ThrowsInvalidIndex() throws {
try interpreter.allocateTensors()
try interpreter.invoke()
XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in

@@ -127,18 +129,18 @@ class InterpreterTests: XCTestCase {
}
}
func testInterpreter_OutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
func testOutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
}
}
func testInterpreter_ResizeInputTensorAtIndexToShape() {
func testResizeInputTensorAtIndexToShape() {
XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
XCTAssertNoThrow(try interpreter.allocateTensors())
}
func testInterpreter_ResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
func testResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
XCTAssertThrowsError(try interpreter.resizeInput(
at: AddModel.invalidIndex,
to: [2, 2, 3]

@@ -151,14 +153,14 @@ class InterpreterTests: XCTestCase {
}
}
func testInterpreter_CopyDataToInputTensorAtIndex() throws {
func testCopyDataToInputTensorAtIndex() throws {
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
try interpreter.allocateTensors()
let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
XCTAssertEqual(inputTensor.data, AddModel.inputData)
}
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
func testCopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
XCTAssertThrowsError(try interpreter.copy(
AddModel.inputData,
toInputAt: AddModel.invalidIndex

@@ -171,7 +173,7 @@ class InterpreterTests: XCTestCase {
}
}
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
func testCopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
try interpreter.allocateTensors()
let invalidData = Data(count: AddModel.dataCount - 1)

@@ -186,7 +188,7 @@ class InterpreterTests: XCTestCase {
}
}
func testInterpreter_AllocateTensors() {
func testAllocateTensors() {
XCTAssertNoThrow(try interpreter.allocateTensors())
}

@@ -215,6 +217,33 @@ class InterpreterTests: XCTestCase {
}
}
class InterpreterOptionsTests: XCTestCase {
func testInitWithDefaultValues() {
let options = Interpreter.Options()
XCTAssertNil(options.threadCount)
}
func testInitWithCustomValues() {
var options = Interpreter.Options()
options.threadCount = 2
XCTAssertEqual(options.threadCount, 2)
}
func testEquatable() {
var options1 = Interpreter.Options()
var options2 = Interpreter.Options()
XCTAssertEqual(options1, options2)
options1.threadCount = 2
options2.threadCount = 2
XCTAssertEqual(options1, options2)
options2.threadCount = 3
XCTAssertNotEqual(options1, options2)
}
}
// MARK: - Constants
/// Values for the `add.bin` model.

@@ -224,7 +253,7 @@ private enum AddModel {
static let outputTensorCount = 1
static let invalidIndex = 1
static let validIndex = 0
static let shape: TensorShape = [2]
static let shape: Tensor.Shape = [2]
static let dataCount = inputData.count
static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
static let outputData = Data(copyingBufferOf: [Float32(3.0), Float32(9.0)])

@@ -254,7 +283,7 @@ private enum AddModel {
private enum AddQuantizedModel {
static let info = (name: "add_quantized", extension: "bin")
static let inputOutputIndex = 0
static let shape: TensorShape = [2]
static let shape: Tensor.Shape = [2]
static let inputData = Data([1, 3])
static let outputData = Data([3, 9])
static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)
@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
@testable import TensorFlowLite
class ModelTests: XCTestCase {
var modelPath: String!

@@ -39,15 +40,15 @@ class ModelTests: XCTestCase {
super.tearDown()
}
func testModel_InitWithFilePath() {
func testInitWithFilePath() {
XCTAssertNotNil(Model(filePath: modelPath))
}
func testModel_InitWithEmptyFilePath_FailsInitialization() {
func testInitWithEmptyFilePath_FailsInitialization() {
XCTAssertNil(Model(filePath: ""))
}
func testModel_InitWithInvalidFilePath_FailsInitialization() {
func testInitWithInvalidFilePath_FailsInitialization() {
XCTAssertNil(Model(filePath: "invalid/path"))
}
}
@@ -12,18 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
@testable import TensorFlowLite
class QuantizationParametersTests: XCTestCase {
func testQuantizationParameters_InitWithCustomValues() {
func testInitWithCustomValues() {
let parameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
XCTAssertEqual(parameters.scale, 0.5)
XCTAssertEqual(parameters.zeroPoint, 1)
}
func testQuantizationParameters_Equatable() {
func testEquatable() {
let parameters1 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
let parameters2 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
XCTAssertEqual(parameters1, parameters2)
@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
@testable import TensorFlowLite
class TensorFlowLiteTests: XCTestCase {
func testTensorFlowLite_Runtime_version() {
func testRuntime_Version() {
#if swift(>=5.0)
let pattern = #"^(\d+)\.(\d+)\.(\d+)([+-][-.0-9A-Za-z]+)?$"#
#else
@@ -12,17 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
@testable import TensorFlowLite
class TensorTests: XCTestCase {
// MARK: - Tensor
func testTensor_Init() {
func testInit() {
let name = "InputTensor"
let dataType: TensorDataType = .uInt8
let shape = TensorShape(Constant.dimensions)
let dataType: Tensor.DataType = .uInt8
let shape = Tensor.Shape(Constant.dimensions)
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
let inputTensor = Tensor(

@@ -39,10 +38,10 @@ class TensorTests: XCTestCase {
XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters)
}
func testTensor_Equatable() {
func testEquatable() {
let name = "Tensor"
let dataType: TensorDataType = .uInt8
let shape = TensorShape(Constant.dimensions)
let dataType: Tensor.DataType = .uInt8
let shape = Tensor.Shape(Constant.dimensions)
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
let tensor1 = Tensor(

@@ -70,30 +69,31 @@ class TensorTests: XCTestCase {
)
XCTAssertNotEqual(tensor1, tensor2)
}
}
// MARK: - TensorShape
class TensorShapeTests: XCTestCase {
func testTensorShape_InitWithArray() {
let shape = TensorShape(Constant.dimensions)
func testInitWithArray() {
let shape = Tensor.Shape(Constant.dimensions)
XCTAssertEqual(shape.rank, Constant.dimensions.count)
XCTAssertEqual(shape.dimensions, Constant.dimensions)
}
func testTensorShape_InitWithElements() {
let shape = TensorShape(2, 2, 3)
func testInitWithElements() {
let shape = Tensor.Shape(2, 2, 3)
XCTAssertEqual(shape.rank, Constant.dimensions.count)
XCTAssertEqual(shape.dimensions, Constant.dimensions)
}
func testTensorShape_InitWithArrayLiteral() {
let shape: TensorShape = [2, 2, 3]
func testInitWithArrayLiteral() {
let shape: Tensor.Shape = [2, 2, 3]
XCTAssertEqual(shape.rank, Constant.dimensions.count)
XCTAssertEqual(shape.dimensions, Constant.dimensions)
}
func testTensorShape_Equatable() {
let shape1 = TensorShape(2, 2, 3)
var shape2: TensorShape = [2, 2, 3]
func testEquatable() {
let shape1 = Tensor.Shape(2, 2, 3)
var shape2: Tensor.Shape = [2, 2, 3]
XCTAssertEqual(shape1, shape2)
shape2 = [2, 2, 4]