Updates `InterpreterOptions`, `TensorShape`, and `TensorDataType` to be nested types: `Interpreter.Options`, `Tensor.Shape`, and `Tensor.DataType`.

PiperOrigin-RevId: 265160862
Authored by A. Unique TensorFlower on 2019-08-23 16:36:29 -07:00; committed by TensorFlower Gardener
parent 2fce6eaf6d
commit c3a9eb7c28
14 changed files with 211 additions and 234 deletions
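
In short, the three top-level Swift types are folded into their owning types as `Interpreter.Options`, `Tensor.Shape`, and `Tensor.DataType`. A minimal before/after sketch of what call sites look like, based only on the API shown in the diff below (values are illustrative):

    // Before this commit: top-level types.
    var options = InterpreterOptions()
    options.threadCount = 2
    let shape: TensorShape = [2]
    let dataType: TensorDataType = .float32

    // After this commit: the same types, nested under Interpreter and Tensor.
    var newOptions = Interpreter.Options()
    newOptions.threadCount = 2
    let newShape: Tensor.Shape = [2]
    let newDataType: Tensor.DataType = .float32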


@@ -16,9 +16,9 @@ import TensorFlowLiteC
 /// A delegate that the `Interpreter` uses to perform TensorFlow Lite model computations.
 public protocol Delegate: class {
-  /// `TFL_Delegate` C pointer type.
-  typealias CDelegate = OpaquePointer
+  /// The `TfLiteDelegate` C pointer type.
+  typealias CDelegate = UnsafeMutablePointer<TfLiteDelegate>

-  /// Delegate that performs model computations.
-  var cDelegate: CDelegate? { get }
+  /// The delegate that performs model computations.
+  var cDelegate: CDelegate { get }
 }


@@ -17,32 +17,36 @@ import TensorFlowLiteC
 /// A TensorFlow Lite interpreter that performs inference from a given model.
 public final class Interpreter {
-  /// `TfLiteInterpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
-  private typealias CInterpreter = OpaquePointer
+  /// The configuration options for the `Interpreter`.
+  public let options: Options?

-  /// Total number of input tensors associated with the model.
+  /// The total number of input tensors associated with the model.
   public var inputTensorCount: Int {
     return Int(TfLiteInterpreterGetInputTensorCount(cInterpreter))
   }

-  /// Total number of output tensors associated with the model.
+  /// The total number of output tensors associated with the model.
   public var outputTensorCount: Int {
     return Int(TfLiteInterpreterGetOutputTensorCount(cInterpreter))
   }

-  /// Underlying `TfLiteInterpreter` C pointer.
+  /// The `TfLiteInterpreter` C pointer type represented as an `UnsafePointer<TfLiteInterpreter>`.
+  private typealias CInterpreter = OpaquePointer
+
+  /// The underlying `TfLiteInterpreter` C pointer.
   private var cInterpreter: CInterpreter?

-  /// Creates a new model interpreter instance.
+  /// Creates a new instance with the given values.
   ///
   /// - Parameters:
   ///   - modelPath: Local file path to a TensorFlow Lite model.
-  ///   - options: Custom configurations for the interpreter. Default is `nil` indicating that the
-  ///       interpreter will determine the configuration options.
+  ///   - options: Custom configuration options for the interpreter. Default is `nil` indicating
+  ///       that the interpreter will determine the configuration options.
   /// - Throws: An error if the model could not be loaded or the interpreter could not be created.
-  public init(modelPath: String, options: InterpreterOptions? = nil) throws {
+  public init(modelPath: String, options: Options? = nil) throws {
     guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
+    self.options = options
     let cInterpreterOptions: OpaquePointer? = try options.map { options in
       guard let cOptions = TfLiteInterpreterOptionsCreate() else {
         throw InterpreterError.failedToCreateInterpreter
@@ -105,14 +109,14 @@ public final class Interpreter {
     else {
       throw InterpreterError.allocateTensorsRequired
     }
-    guard let dataType = TensorDataType(type: TfLiteTensorType(cTensor)) else {
+    guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else {
       throw InterpreterError.invalidTensorDataType
     }
     let name = String(cString: nameCString)
     let rank = TfLiteTensorNumDims(cTensor)
     let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
-    let shape = TensorShape(dimensions)
+    let shape = Tensor.Shape(dimensions)
     let byteCount = TfLiteTensorByteSize(cTensor)
     let data = Data(bytes: bytes, count: byteCount)
     let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)
@@ -151,14 +155,14 @@ public final class Interpreter {
     else {
       throw InterpreterError.invokeInterpreterRequired
     }
-    guard let dataType = TensorDataType(type: TfLiteTensorType(cTensor)) else {
+    guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else {
      throw InterpreterError.invalidTensorDataType
     }
     let name = String(cString: nameCString)
     let rank = TfLiteTensorNumDims(cTensor)
     let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
-    let shape = TensorShape(dimensions)
+    let shape = Tensor.Shape(dimensions)
     let byteCount = TfLiteTensorByteSize(cTensor)
     let data = Data(bytes: bytes, count: byteCount)
     let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)
@@ -187,7 +191,7 @@ public final class Interpreter {
   ///   - index: The index for the input tensor.
   ///   - shape: The shape that the input tensor should be resized to.
   /// - Throws: An error if the input tensor at the given index could not be resized.
-  public func resizeInput(at index: Int, to shape: TensorShape) throws {
+  public func resizeInput(at index: Int, to shape: Tensor.Shape) throws {
     let maxIndex = inputTensorCount - 1
     guard case 0...maxIndex = index else {
       throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
@@ -237,7 +241,7 @@ public final class Interpreter {
     return try input(at: index)
   }

-  /// Allocates memory for all input tensors based on their `TensorShape`s.
+  /// Allocates memory for all input tensors based on their `Tensor.Shape`s.
   ///
   /// - Note: This is a relatively expensive operation and should only be called after creating the
   ///     interpreter and/or resizing any input tensors.
@@ -249,7 +253,17 @@ public final class Interpreter {
   }
 }

-// MARK: - Extensions
+extension Interpreter {
+  /// Options for configuring the `Interpreter`.
+  public struct Options: Equatable, Hashable {
+    /// The maximum number of CPU threads that the interpreter should run on. Default is `nil`
+    /// indicating that the `Interpreter` will decide the number of threads to use.
+    public var threadCount: Int? = nil
+
+    /// Creates a new instance with the default values.
+    public init() {}
+  }
+}

 extension String {
   /// Returns a new `String` initialized by using the given format C array as a template into which
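
For context, the nested `Interpreter.Options` is used the same way the old `InterpreterOptions` was, and the interpreter now also exposes the configuration it was created with via its `options` property. A rough usage sketch based on the API in this diff (the helper name and model path are illustrative):

    import TensorFlowLite

    func makeInterpreter(modelPath: String) throws -> Interpreter {
      var options = Interpreter.Options()
      options.threadCount = 2
      // The interpreter retains the configuration it was created with (see `interpreter.options`).
      let interpreter = try Interpreter(modelPath: modelPath, options: options)
      try interpreter.allocateTensors()
      return interpreter
    }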


@@ -14,7 +14,7 @@
 import Foundation

-/// TensorFlow Lite interpreter errors.
+/// Errors thrown by the TensorFlow Lite `Interpreter`.
 public enum InterpreterError: Error, Equatable, Hashable {
   case invalidTensorIndex(index: Int, maxIndex: Int)
   case invalidTensorDataCount(provided: Int, required: Int)
@@ -29,10 +29,8 @@ public enum InterpreterError: Error, Equatable, Hashable {
   case tensorFlowLiteError(String)
 }

-// MARK: - Extensions
-
 extension InterpreterError: LocalizedError {
-  /// Localized description of the interpreter error.
+  /// A localized description of the interpreter error.
   public var errorDescription: String? {
     switch self {
     case .invalidTensorIndex(let index, let maxIndex):
@@ -62,6 +60,6 @@ extension InterpreterError: LocalizedError {
 }

 extension InterpreterError: CustomStringConvertible {
-  /// Textual representation of the TensorFlow Lite interpreter error.
+  /// A textual representation of the TensorFlow Lite interpreter error.
   public var description: String { return errorDescription ?? "Unknown error." }
 }


@@ -1,23 +0,0 @@
-// Copyright 2018 Google Inc. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at:
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-/// Custom configuration options for a TensorFlow Lite `Interpreter`.
-public struct InterpreterOptions: Equatable {
-  /// Maximum number of CPU threads that the interpreter should run on. Default is `nil` indicating
-  /// that the `Interpreter` will decide the number of threads to use.
-  public var threadCount: Int? = nil
-
-  /// Creates a new instance of interpreter options.
-  public init() {}
-}


@@ -14,15 +14,15 @@
 import TensorFlowLiteC

-/// A TensorFlow Lite model used by the 'Interpreter` to perform inference.
+/// A TensorFlow Lite model used by the `Interpreter` to perform inference.
 final class Model {
-  /// `TfLiteModel` C pointer type represented as an `UnsafePointer<TfLiteModel>`.
+  /// The `TfLiteModel` C pointer type represented as an `UnsafePointer<TfLiteModel>`.
   typealias CModel = OpaquePointer

-  /// Underlying `TfLiteModel` C pointer.
+  /// The underlying `TfLiteModel` C pointer.
   let cModel: CModel?

-  /// Creates a new model instance.
+  /// Creates a new instance with the given `filePath`.
   ///
   /// - Precondition: Initialization can fail if the given `filePath` is invalid.
   /// - Parameters:


@@ -16,11 +16,11 @@
 /// be mapped to float values using the following conversion:
 /// `realValue = scale * (quantizedValue - zeroPoint)`.
 public struct QuantizationParameters: Equatable, Hashable {
-  /// Difference between real values corresponding to consecutive quantized values differing by 1.
-  /// For example, the range of quantized values for `UInt8` data type is [0, 255].
+  /// The difference between real values corresponding to consecutive quantized values differing by
+  /// 1. For example, the range of quantized values for `UInt8` data type is [0, 255].
   public let scale: Float

-  /// Quantized value that corresponds to the real 0 value.
+  /// The quantized value that corresponds to the real 0 value.
   public let zeroPoint: Int

   /// Creates a new quantization parameters instance.


@@ -17,19 +17,19 @@ import TensorFlowLiteC
 /// An input or output tensor in a TensorFlow Lite graph.
 public struct Tensor: Equatable, Hashable {
-  /// Name of the tensor.
+  /// The name of the tensor.
   public let name: String

-  /// Data type of the tensor.
-  public let dataType: TensorDataType
+  /// The data type of the tensor.
+  public let dataType: DataType

-  /// Shape of the tensor.
-  public let shape: TensorShape
+  /// The shape of the tensor.
+  public let shape: Shape

-  /// Data in the input or output tensor.
+  /// The data in the input or output tensor.
   public let data: Data

-  /// Quantization parameters for the tensor if using a quantized model.
+  /// The quantization parameters for the tensor if using a quantized model.
   public let quantizationParameters: QuantizationParameters?

   /// Creates a new input or output tensor instance.
@@ -43,8 +43,8 @@ public struct Tensor: Equatable, Hashable {
   ///       Default is `nil`.
   init(
     name: String,
-    dataType: TensorDataType,
-    shape: TensorShape,
+    dataType: DataType,
+    shape: Shape,
     data: Data,
     quantizationParameters: QuantizationParameters? = nil
   ) {
@@ -56,83 +56,86 @@ public struct Tensor: Equatable, Hashable {
   }
 }

-/// Supported TensorFlow Lite tensor data types.
-public enum TensorDataType: Equatable, Hashable {
-  /// Boolean.
-  case bool
-  /// 8-bit unsigned integer.
-  case uInt8
-  /// 16-bit signed integer.
-  case int16
-  /// 32-bit signed integer.
-  case int32
-  /// 64-bit signed integer.
-  case int64
-  /// 16-bit half precision floating point.
-  case float16
-  /// 32-bit single precision floating point.
-  case float32
-  /// Creates a new tensor data type from the given `TfLiteType` or `nil` if the data type is
-  /// unsupported or could not be determined because there was an error.
+extension Tensor {
+  /// The supported `Tensor` data types.
+  public enum DataType: Equatable, Hashable {
+    /// A boolean.
+    case bool
+    /// An 8-bit unsigned integer.
+    case uInt8
+    /// A 16-bit signed integer.
+    case int16
+    /// A 32-bit signed integer.
+    case int32
+    /// A 64-bit signed integer.
+    case int64
+    /// A 16-bit half precision floating point.
+    case float16
+    /// A 32-bit single precision floating point.
+    case float32
+    /// Creates a new instance from the given `TfLiteType` or `nil` if the data type is unsupported
+    /// or could not be determined because there was an error.
     ///
-  /// - Parameter type: A data type supported by a tensor.
+    /// - Parameter type: Data type supported by a tensor.
     init?(type: TfLiteType) {
       switch type {
       case kTfLiteBool:
         self = .bool
       case kTfLiteUInt8:
         self = .uInt8
       case kTfLiteInt16:
         self = .int16
       case kTfLiteInt32:
         self = .int32
       case kTfLiteInt64:
         self = .int64
       case kTfLiteFloat16:
         self = .float16
       case kTfLiteFloat32:
         self = .float32
       case kTfLiteNoType:
         fallthrough
       default:
         return nil
       }
     }
+  }
 }

-/// The shape of a TensorFlow Lite tensor.
-public struct TensorShape: Equatable, Hashable {
-  /// The number of dimensions of the tensor.
-  public let rank: Int
-  /// Array of dimensions for the tensor.
-  public let dimensions: [Int]
-  /// Array of `Int32` dimensions for the tensor.
-  var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
-  /// Creates a new tensor shape instance with the given array of dimensions.
+extension Tensor {
+  /// The shape of a `Tensor`.
+  public struct Shape: Equatable, Hashable {
+    /// The number of dimensions of the tensor.
+    public let rank: Int
+    /// An array of dimensions for the tensor.
+    public let dimensions: [Int]
+    /// An array of `Int32` dimensions for the tensor.
+    var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
+    /// Creates a new instance with the given array of dimensions.
+    ///
+    /// - Parameters:
+    ///   - dimensions: Dimensions for the tensor.
+    public init(_ dimensions: [Int]) {
+      self.rank = dimensions.count
+      self.dimensions = dimensions
+    }
+    /// Creates a new instance with the given elements representing the dimensions.
     ///
     /// - Parameters:
-  ///   - dimensions: Dimensions for the tensor.
-  public init(_ dimensions: [Int]) {
-    self.rank = dimensions.count
-    self.dimensions = dimensions
-  }
-  /// Creates a new tensor shape instance with the given elements representing the dimensions.
-  ///
-  /// - Parameters:
-  ///   - elements: Dimensions for the tensor.
-  public init(_ elements: Int...) {
-    self.init(elements)
+    ///   - elements: Dimensions for the tensor.
+    public init(_ elements: Int...) {
+      self.init(elements)
+    }
   }
 }

-extension TensorShape: ExpressibleByArrayLiteral {
-  /// Creates a new tensor shape instance with the given array literal representing the dimensions.
+extension Tensor.Shape: ExpressibleByArrayLiteral {
+  /// Creates a new instance with the given array literal representing the dimensions.
   ///
   /// - Parameters:
   ///   - arrayLiteral: Dimensions for the tensor.
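
The nested `Tensor.Shape` keeps the same three initializers as the old `TensorShape` (array, variadic elements, and array literal), so call sites only change in the type name. A small sketch of the equivalent spellings, using illustrative dimensions:

    // All three spellings describe the same shape; rank and dimensions are derived from the input.
    let fromArray = Tensor.Shape([2, 2, 3])
    let fromElements = Tensor.Shape(2, 2, 3)
    let fromLiteral: Tensor.Shape = [2, 2, 3]
    // fromArray.rank == 3; fromArray.dimensions == [2, 2, 3]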


@@ -12,18 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-import class TensorFlowLite.Interpreter
-import struct TensorFlowLite.InterpreterOptions
-import struct TensorFlowLite.Tensor
-import struct TensorFlowLite.TensorShape
-import enum TensorFlowLite.Runtime
+import TensorFlowLite
 import UIKit

 class ViewController: UIViewController {

   // MARK: - Properties

-  /// TensorFlowLite interpreter object for performing inference from a given model.
+  /// TensorFlow Lite interpreter object for performing inference from a given model.
   private var interpreter: Interpreter?

   /// Serial dispatch queue for managing `Interpreter` calls.
@@ -122,7 +118,7 @@ class ViewController: UIViewController {
   private func setUpInterpreter(withModelPath modelPath: String) {
     interpreterQueue.async {
       do {
-        var options = InterpreterOptions()
+        var options = Interpreter.Options()
         options.threadCount = 2
         self.interpreter = try Interpreter(modelPath: modelPath, options: options)
       } catch let error {
@@ -211,7 +207,7 @@ class ViewController: UIViewController {
         return
       }
       do {
-        let shape = TensorShape(2)
+        let shape = Tensor.Shape(2)
         try (0..<interpreter.inputTensorCount).forEach { index in
           try interpreter.resizeInput(at: index, to: shape)
         }


@@ -1,43 +0,0 @@
-// Copyright 2018 Google Inc. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at:
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-@testable import TensorFlowLite
-import XCTest
-
-class InterpreterOptionsTests: XCTestCase {
-
-  func testInterpreterOptions_InitWithDefaultValues() {
-    let options = InterpreterOptions()
-    XCTAssertNil(options.threadCount)
-  }
-
-  func testInterpreterOptions_InitWithCustomValues() {
-    var options = InterpreterOptions()
-    options.threadCount = 2
-    XCTAssertEqual(options.threadCount, 2)
-  }
-
-  func testInterpreterOptions_Equatable() {
-    var options1 = InterpreterOptions()
-    var options2 = InterpreterOptions()
-    XCTAssertEqual(options1, options2)
-
-    options1.threadCount = 2
-    options2.threadCount = 2
-    XCTAssertEqual(options1, options2)
-
-    options2.threadCount = 3
-    XCTAssertNotEqual(options1, options2)
-  }
-}


@@ -12,9 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-@testable import TensorFlowLite
 import XCTest

+@testable import TensorFlowLite
+
 class InterpreterTests: XCTestCase {

   var interpreter: Interpreter!
@@ -31,55 +32,56 @@ class InterpreterTests: XCTestCase {
     super.tearDown()
   }

-  func testInterpreter_InitWithModelPath() {
+  func testInitWithModelPath() {
     XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
   }

-  func testInterpreter_Init_ThrowsFailedToLoadModel() {
+  func testInit_ThrowsFailedToLoadModel() {
     XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
       self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
     }
   }

-  func testInterpreter_InitWithModelPathAndOptions() {
-    var options = InterpreterOptions()
+  func testInitWithModelPathAndOptions() throws {
+    var options = Interpreter.Options()
     options.threadCount = 2
-    XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path, options: options))
+    let interpreter = try Interpreter(modelPath: AddModel.path, options: options)
+    XCTAssertNotNil(interpreter.options)
   }

-  func testInterpreter_InputTensorCount() {
+  func testInputTensorCount() {
     XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
   }

-  func testInterpreter_OutputTensorCount() {
+  func testOutputTensorCount() {
     XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
   }

-  func testInterpreter_Invoke() throws {
+  func testInvoke() throws {
     try interpreter.allocateTensors()
     XCTAssertNoThrow(try interpreter.invoke())
   }

-  func testInterpreter_Invoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
+  func testInvoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
     XCTAssertThrowsError(try interpreter.invoke()) { error in
       self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
     }
   }

-  func testInterpreter_InputTensorAtIndex() throws {
+  func testInputTensorAtIndex() throws {
     try setUpAddModelInputTensor()
     let inputTensor = try interpreter.input(at: AddModel.validIndex)
     XCTAssertEqual(inputTensor, AddModel.inputTensor)
   }

-  func testInterpreter_InputTensorAtIndex_QuantizedModel() throws {
+  func testInputTensorAtIndex_QuantizedModel() throws {
     interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
     try setUpAddQuantizedModelInputTensor()
     let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
     XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
   }

-  func testInterpreter_InputTensorAtIndex_ThrowsInvalidIndex() throws {
+  func testInputTensorAtIndex_ThrowsInvalidIndex() throws {
     try interpreter.allocateTensors()
     XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
       let maxIndex = AddModel.inputTensorCount - 1
@@ -90,13 +92,13 @@ class InterpreterTests: XCTestCase {
     }
   }

-  func testInterpreter_InputTensorAtIndex_ThrowsAllocateTensorsRequired() {
+  func testInputTensorAtIndex_ThrowsAllocateTensorsRequired() {
     XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
       self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
     }
   }

-  func testInterpreter_OutputTensorAtIndex() throws {
+  func testOutputTensorAtIndex() throws {
     try setUpAddModelInputTensor()
     try interpreter.invoke()
     let outputTensor = try interpreter.output(at: AddModel.validIndex)
@@ -105,7 +107,7 @@ class InterpreterTests: XCTestCase {
     XCTAssertEqual(expectedResults, AddModel.results)
   }

-  func testInterpreter_OutputTensorAtIndex_QuantizedModel() throws {
+  func testOutputTensorAtIndex_QuantizedModel() throws {
     interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
     try setUpAddQuantizedModelInputTensor()
     try interpreter.invoke()
@@ -115,7 +117,7 @@ class InterpreterTests: XCTestCase {
     XCTAssertEqual(expectedResults, AddQuantizedModel.results)
   }

-  func testInterpreter_OutputTensorAtIndex_ThrowsInvalidIndex() throws {
+  func testOutputTensorAtIndex_ThrowsInvalidIndex() throws {
     try interpreter.allocateTensors()
     try interpreter.invoke()
     XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in
@@ -127,18 +129,18 @@ class InterpreterTests: XCTestCase {
     }
   }

-  func testInterpreter_OutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
+  func testOutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
     XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
       self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
     }
   }

-  func testInterpreter_ResizeInputTensorAtIndexToShape() {
+  func testResizeInputTensorAtIndexToShape() {
     XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
     XCTAssertNoThrow(try interpreter.allocateTensors())
   }

-  func testInterpreter_ResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
+  func testResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
     XCTAssertThrowsError(try interpreter.resizeInput(
       at: AddModel.invalidIndex,
       to: [2, 2, 3]
@@ -151,14 +153,14 @@ class InterpreterTests: XCTestCase {
     }
   }

-  func testInterpreter_CopyDataToInputTensorAtIndex() throws {
+  func testCopyDataToInputTensorAtIndex() throws {
     try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
     try interpreter.allocateTensors()
     let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
     XCTAssertEqual(inputTensor.data, AddModel.inputData)
   }

-  func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
+  func testCopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
     XCTAssertThrowsError(try interpreter.copy(
       AddModel.inputData,
       toInputAt: AddModel.invalidIndex
@@ -171,7 +173,7 @@ class InterpreterTests: XCTestCase {
     }
   }

-  func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
+  func testCopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
     try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
     try interpreter.allocateTensors()
     let invalidData = Data(count: AddModel.dataCount - 1)
@@ -186,7 +188,7 @@ class InterpreterTests: XCTestCase {
     }
   }

-  func testInterpreter_AllocateTensors() {
+  func testAllocateTensors() {
     XCTAssertNoThrow(try interpreter.allocateTensors())
   }
@@ -215,6 +217,33 @@ class InterpreterTests: XCTestCase {
   }
 }

+class InterpreterOptionsTests: XCTestCase {
+
+  func testInitWithDefaultValues() {
+    let options = Interpreter.Options()
+    XCTAssertNil(options.threadCount)
+  }
+
+  func testInitWithCustomValues() {
+    var options = Interpreter.Options()
+    options.threadCount = 2
+    XCTAssertEqual(options.threadCount, 2)
+  }
+
+  func testEquatable() {
+    var options1 = Interpreter.Options()
+    var options2 = Interpreter.Options()
+    XCTAssertEqual(options1, options2)
+
+    options1.threadCount = 2
+    options2.threadCount = 2
+    XCTAssertEqual(options1, options2)
+
+    options2.threadCount = 3
+    XCTAssertNotEqual(options1, options2)
+  }
+}
+
 // MARK: - Constants

 /// Values for the `add.bin` model.
@@ -224,7 +253,7 @@ private enum AddModel {
   static let outputTensorCount = 1
   static let invalidIndex = 1
   static let validIndex = 0
-  static let shape: TensorShape = [2]
+  static let shape: Tensor.Shape = [2]
   static let dataCount = inputData.count
   static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
   static let outputData = Data(copyingBufferOf: [Float32(3.0), Float32(9.0)])
@@ -254,7 +283,7 @@ private enum AddModel {
 private enum AddQuantizedModel {
   static let info = (name: "add_quantized", extension: "bin")
   static let inputOutputIndex = 0
-  static let shape: TensorShape = [2]
+  static let shape: Tensor.Shape = [2]
   static let inputData = Data([1, 3])
   static let outputData = Data([3, 9])
   static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)


@@ -12,9 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-@testable import TensorFlowLite
 import XCTest

+@testable import TensorFlowLite
+
 class ModelTests: XCTestCase {

   var modelPath: String!
@@ -39,15 +40,15 @@ class ModelTests: XCTestCase {
     super.tearDown()
   }

-  func testModel_InitWithFilePath() {
+  func testInitWithFilePath() {
     XCTAssertNotNil(Model(filePath: modelPath))
   }

-  func testModel_InitWithEmptyFilePath_FailsInitialization() {
+  func testInitWithEmptyFilePath_FailsInitialization() {
     XCTAssertNil(Model(filePath: ""))
   }

-  func testModel_InitWithInvalidFilePath_FailsInitialization() {
+  func testInitWithInvalidFilePath_FailsInitialization() {
     XCTAssertNil(Model(filePath: "invalid/path"))
   }
 }


@@ -12,18 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-@testable import TensorFlowLite
 import XCTest

+@testable import TensorFlowLite
+
 class QuantizationParametersTests: XCTestCase {

-  func testQuantizationParameters_InitWithCustomValues() {
+  func testInitWithCustomValues() {
     let parameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
     XCTAssertEqual(parameters.scale, 0.5)
     XCTAssertEqual(parameters.zeroPoint, 1)
   }

-  func testQuantizationParameters_Equatable() {
+  func testEquatable() {
     let parameters1 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
     let parameters2 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
     XCTAssertEqual(parameters1, parameters2)


@@ -12,12 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-@testable import TensorFlowLite
 import XCTest

+@testable import TensorFlowLite
+
 class TensorFlowLiteTests: XCTestCase {

-  func testTensorFlowLite_Runtime_version() {
+  func testRuntime_Version() {
     #if swift(>=5.0)
       let pattern = #"^(\d+)\.(\d+)\.(\d+)([+-][-.0-9A-Za-z]+)?$"#
     #else


@@ -12,17 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-@testable import TensorFlowLite
 import XCTest

+@testable import TensorFlowLite
+
 class TensorTests: XCTestCase {

-  // MARK: - Tensor
-  func testTensor_Init() {
+  func testInit() {
     let name = "InputTensor"
-    let dataType: TensorDataType = .uInt8
-    let shape = TensorShape(Constant.dimensions)
+    let dataType: Tensor.DataType = .uInt8
+    let shape = Tensor.Shape(Constant.dimensions)
     guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
     let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
     let inputTensor = Tensor(
@@ -39,10 +38,10 @@ class TensorTests: XCTestCase {
     XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters)
   }

-  func testTensor_Equatable() {
+  func testEquatable() {
     let name = "Tensor"
-    let dataType: TensorDataType = .uInt8
-    let shape = TensorShape(Constant.dimensions)
+    let dataType: Tensor.DataType = .uInt8
+    let shape = Tensor.Shape(Constant.dimensions)
     guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
     let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
     let tensor1 = Tensor(
@@ -70,30 +69,31 @@ class TensorTests: XCTestCase {
     )
     XCTAssertNotEqual(tensor1, tensor2)
   }
+}

-  // MARK: - TensorShape
+class TensorShapeTests: XCTestCase {

-  func testTensorShape_InitWithArray() {
-    let shape = TensorShape(Constant.dimensions)
+  func testInitWithArray() {
+    let shape = Tensor.Shape(Constant.dimensions)
     XCTAssertEqual(shape.rank, Constant.dimensions.count)
     XCTAssertEqual(shape.dimensions, Constant.dimensions)
   }

-  func testTensorShape_InitWithElements() {
-    let shape = TensorShape(2, 2, 3)
+  func testInitWithElements() {
+    let shape = Tensor.Shape(2, 2, 3)
     XCTAssertEqual(shape.rank, Constant.dimensions.count)
     XCTAssertEqual(shape.dimensions, Constant.dimensions)
   }

-  func testTensorShape_InitWithArrayLiteral() {
-    let shape: TensorShape = [2, 2, 3]
+  func testInitWithArrayLiteral() {
+    let shape: Tensor.Shape = [2, 2, 3]
     XCTAssertEqual(shape.rank, Constant.dimensions.count)
     XCTAssertEqual(shape.dimensions, Constant.dimensions)
   }

-  func testTensorShape_Equatable() {
-    let shape1 = TensorShape(2, 2, 3)
-    var shape2: TensorShape = [2, 2, 3]
+  func testEquatable() {
+    let shape1 = Tensor.Shape(2, 2, 3)
+    var shape2: Tensor.Shape = [2, 2, 3]
     XCTAssertEqual(shape1, shape2)
     shape2 = [2, 2, 4]