Updates InterpreterOptions, TensorShape, and TensorDataType to be nested types.

PiperOrigin-RevId: 265160862
This commit is contained in:
A. Unique TensorFlower 2019-08-23 16:36:29 -07:00 committed by TensorFlower Gardener
parent 2fce6eaf6d
commit c3a9eb7c28
14 changed files with 211 additions and 234 deletions

View File

@ -16,9 +16,9 @@ import TensorFlowLiteC
/// A delegate that the `Interpreter` uses to perform TensorFlow Lite model computations.
public protocol Delegate: class {
/// `TFL_Delegate` C pointer type.
typealias CDelegate = OpaquePointer
/// The `TfLiteDelegate` C pointer type.
typealias CDelegate = UnsafeMutablePointer<TfLiteDelegate>
/// Delegate that performs model computations.
var cDelegate: CDelegate? { get }
/// The delegate that performs model computations.
var cDelegate: CDelegate { get }
}

View File

@ -17,32 +17,36 @@ import TensorFlowLiteC
/// A TensorFlow Lite interpreter that performs inference from a given model.
public final class Interpreter {
/// `TfLiteInterpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
private typealias CInterpreter = OpaquePointer
/// The configuration options for the `Interpreter`.
public let options: Options?
/// Total number of input tensors associated with the model.
/// The total number of input tensors associated with the model.
public var inputTensorCount: Int {
return Int(TfLiteInterpreterGetInputTensorCount(cInterpreter))
}
/// Total number of output tensors associated with the model.
/// The total number of output tensors associated with the model.
public var outputTensorCount: Int {
return Int(TfLiteInterpreterGetOutputTensorCount(cInterpreter))
}
/// Underlying `TfLiteInterpreter` C pointer.
/// The `TfLiteInterpreter` C pointer type represented as an `UnsafePointer<TfLiteInterpreter>`.
private typealias CInterpreter = OpaquePointer
/// The underlying `TfLiteInterpreter` C pointer.
private var cInterpreter: CInterpreter?
/// Creates a new model interpreter instance.
/// Creates a new instance with the given values.
///
/// - Parameters:
/// - modelPath: Local file path to a TensorFlow Lite model.
/// - options: Custom configurations for the interpreter. Default is `nil` indicating that the
/// interpreter will determine the configuration options.
/// - options: Custom configuration options for the interpreter. Default is `nil` indicating
/// that the interpreter will determine the configuration options.
/// - Throws: An error if the model could not be loaded or the interpreter could not be created.
public init(modelPath: String, options: InterpreterOptions? = nil) throws {
public init(modelPath: String, options: Options? = nil) throws {
guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
self.options = options
let cInterpreterOptions: OpaquePointer? = try options.map { options in
guard let cOptions = TfLiteInterpreterOptionsCreate() else {
throw InterpreterError.failedToCreateInterpreter
@ -105,14 +109,14 @@ public final class Interpreter {
else {
throw InterpreterError.allocateTensorsRequired
}
guard let dataType = TensorDataType(type: TfLiteTensorType(cTensor)) else {
guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else {
throw InterpreterError.invalidTensorDataType
}
let name = String(cString: nameCString)
let rank = TfLiteTensorNumDims(cTensor)
let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
let shape = TensorShape(dimensions)
let shape = Tensor.Shape(dimensions)
let byteCount = TfLiteTensorByteSize(cTensor)
let data = Data(bytes: bytes, count: byteCount)
let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)
@ -151,14 +155,14 @@ public final class Interpreter {
else {
throw InterpreterError.invokeInterpreterRequired
}
guard let dataType = TensorDataType(type: TfLiteTensorType(cTensor)) else {
guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else {
throw InterpreterError.invalidTensorDataType
}
let name = String(cString: nameCString)
let rank = TfLiteTensorNumDims(cTensor)
let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
let shape = TensorShape(dimensions)
let shape = Tensor.Shape(dimensions)
let byteCount = TfLiteTensorByteSize(cTensor)
let data = Data(bytes: bytes, count: byteCount)
let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)
@ -187,7 +191,7 @@ public final class Interpreter {
/// - index: The index for the input tensor.
/// - shape: The shape that the input tensor should be resized to.
/// - Throws: An error if the input tensor at the given index could not be resized.
public func resizeInput(at index: Int, to shape: TensorShape) throws {
public func resizeInput(at index: Int, to shape: Tensor.Shape) throws {
let maxIndex = inputTensorCount - 1
guard case 0...maxIndex = index else {
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
@ -237,7 +241,7 @@ public final class Interpreter {
return try input(at: index)
}
/// Allocates memory for all input tensors based on their `TensorShape`s.
/// Allocates memory for all input tensors based on their `Tensor.Shape`s.
///
/// - Note: This is a relatively expensive operation and should only be called after creating the
/// interpreter and/or resizing any input tensors.
@ -249,7 +253,17 @@ public final class Interpreter {
}
}
// MARK: - Extensions
extension Interpreter {
/// Options for configuring the `Interpreter`.
public struct Options: Equatable, Hashable {
/// The maximum number of CPU threads that the interpreter should run on. Default is `nil`
/// indicating that the `Interpreter` will decide the number of threads to use.
public var threadCount: Int? = nil
/// Creates a new instance with the default values.
public init() {}
}
}
extension String {
/// Returns a new `String` initialized by using the given format C array as a template into which

View File

@ -14,7 +14,7 @@
import Foundation
/// TensorFlow Lite interpreter errors.
/// Errors thrown by the TensorFlow Lite `Interpreter`.
public enum InterpreterError: Error, Equatable, Hashable {
case invalidTensorIndex(index: Int, maxIndex: Int)
case invalidTensorDataCount(provided: Int, required: Int)
@ -29,10 +29,8 @@ public enum InterpreterError: Error, Equatable, Hashable {
case tensorFlowLiteError(String)
}
// MARK: - Extensions
extension InterpreterError: LocalizedError {
/// Localized description of the interpreter error.
/// A localized description of the interpreter error.
public var errorDescription: String? {
switch self {
case .invalidTensorIndex(let index, let maxIndex):
@ -62,6 +60,6 @@ extension InterpreterError: LocalizedError {
}
extension InterpreterError: CustomStringConvertible {
/// Textual representation of the TensorFlow Lite interpreter error.
/// A textual representation of the TensorFlow Lite interpreter error.
public var description: String { return errorDescription ?? "Unknown error." }
}

View File

@ -1,23 +0,0 @@
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// Custom configuration options for a TensorFlow Lite `Interpreter`.
public struct InterpreterOptions: Equatable {
/// Maximum number of CPU threads that the interpreter should run on. Default is `nil` indicating
/// that the `Interpreter` will decide the number of threads to use.
public var threadCount: Int? = nil
/// Creates a new instance of interpreter options.
public init() {}
}

View File

@ -14,15 +14,15 @@
import TensorFlowLiteC
/// A TensorFlow Lite model used by the 'Interpreter` to perform inference.
/// A TensorFlow Lite model used by the `Interpreter` to perform inference.
final class Model {
/// `TfLiteModel` C pointer type represented as an `UnsafePointer<TfLiteModel>`.
/// The `TfLiteModel` C pointer type represented as an `UnsafePointer<TfLiteModel>`.
typealias CModel = OpaquePointer
/// Underlying `TfLiteModel` C pointer.
/// The underlying `TfLiteModel` C pointer.
let cModel: CModel?
/// Creates a new model instance.
/// Creates a new instance with the given `filePath`.
///
/// - Precondition: Initialization can fail if the given `filePath` is invalid.
/// - Parameters:

View File

@ -16,11 +16,11 @@
/// be mapped to float values using the following conversion:
/// `realValue = scale * (quantizedValue - zeroPoint)`.
public struct QuantizationParameters: Equatable, Hashable {
/// Difference between real values corresponding to consecutive quantized values differing by 1.
/// For example, the range of quantized values for `UInt8` data type is [0, 255].
/// The difference between real values corresponding to consecutive quantized values differing by
/// 1. For example, the range of quantized values for `UInt8` data type is [0, 255].
public let scale: Float
/// Quantized value that corresponds to the real 0 value.
/// The quantized value that corresponds to the real 0 value.
public let zeroPoint: Int
/// Creates a new quantization parameters instance.

View File

@ -17,19 +17,19 @@ import TensorFlowLiteC
/// An input or output tensor in a TensorFlow Lite graph.
public struct Tensor: Equatable, Hashable {
/// Name of the tensor.
/// The name of the tensor.
public let name: String
/// Data type of the tensor.
public let dataType: TensorDataType
/// The data type of the tensor.
public let dataType: DataType
/// Shape of the tensor.
public let shape: TensorShape
/// The shape of the tensor.
public let shape: Shape
/// Data in the input or output tensor.
/// The data in the input or output tensor.
public let data: Data
/// Quantization parameters for the tensor if using a quantized model.
/// The quantization parameters for the tensor if using a quantized model.
public let quantizationParameters: QuantizationParameters?
/// Creates a new input or output tensor instance.
@ -43,8 +43,8 @@ public struct Tensor: Equatable, Hashable {
/// Default is `nil`.
init(
name: String,
dataType: TensorDataType,
shape: TensorShape,
dataType: DataType,
shape: Shape,
data: Data,
quantizationParameters: QuantizationParameters? = nil
) {
@ -56,83 +56,86 @@ public struct Tensor: Equatable, Hashable {
}
}
/// Supported TensorFlow Lite tensor data types.
public enum TensorDataType: Equatable, Hashable {
/// Boolean.
case bool
/// 8-bit unsigned integer.
case uInt8
/// 16-bit signed integer.
case int16
/// 32-bit signed integer.
case int32
/// 64-bit signed integer.
case int64
/// 16-bit half precision floating point.
case float16
/// 32-bit single precision floating point.
case float32
extension Tensor {
/// The supported `Tensor` data types.
public enum DataType: Equatable, Hashable {
/// A boolean.
case bool
/// An 8-bit unsigned integer.
case uInt8
/// A 16-bit signed integer.
case int16
/// A 32-bit signed integer.
case int32
/// A 64-bit signed integer.
case int64
/// A 16-bit half precision floating point.
case float16
/// A 32-bit single precision floating point.
case float32
/// Creates a new tensor data type from the given `TfLiteType` or `nil` if the data type is
/// unsupported or could not be determined because there was an error.
///
/// - Parameter type: A data type supported by a tensor.
init?(type: TfLiteType) {
switch type {
case kTfLiteBool:
self = .bool
case kTfLiteUInt8:
self = .uInt8
case kTfLiteInt16:
self = .int16
case kTfLiteInt32:
self = .int32
case kTfLiteInt64:
self = .int64
case kTfLiteFloat16:
self = .float16
case kTfLiteFloat32:
self = .float32
case kTfLiteNoType:
fallthrough
default:
return nil
/// Creates a new instance from the given `TfLiteType` or `nil` if the data type is unsupported
/// or could not be determined because there was an error.
///
/// - Parameter type: Data type supported by a tensor.
init?(type: TfLiteType) {
switch type {
case kTfLiteBool:
self = .bool
case kTfLiteUInt8:
self = .uInt8
case kTfLiteInt16:
self = .int16
case kTfLiteInt32:
self = .int32
case kTfLiteInt64:
self = .int64
case kTfLiteFloat16:
self = .float16
case kTfLiteFloat32:
self = .float32
case kTfLiteNoType:
fallthrough
default:
return nil
}
}
}
}
/// The shape of a TensorFlow Lite tensor.
public struct TensorShape: Equatable, Hashable {
extension Tensor {
/// The shape of a `Tensor`.
public struct Shape: Equatable, Hashable {
/// The number of dimensions of the tensor.
public let rank: Int
/// The number of dimensions of the tensor.
public let rank: Int
/// An array of dimensions for the tensor.
public let dimensions: [Int]
/// Array of dimensions for the tensor.
public let dimensions: [Int]
/// An array of `Int32` dimensions for the tensor.
var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
/// Array of `Int32` dimensions for the tensor.
var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
/// Creates a new instance with the given array of dimensions.
///
/// - Parameters:
/// - dimensions: Dimensions for the tensor.
public init(_ dimensions: [Int]) {
self.rank = dimensions.count
self.dimensions = dimensions
}
/// Creates a new tensor shape instance with the given array of dimensions.
///
/// - Parameters:
/// - dimensions: Dimensions for the tensor.
public init(_ dimensions: [Int]) {
self.rank = dimensions.count
self.dimensions = dimensions
}
/// Creates a new tensor shape instance with the given elements representing the dimensions.
///
/// - Parameters:
/// - elements: Dimensions for the tensor.
public init(_ elements: Int...) {
self.init(elements)
/// Creates a new instance with the given elements representing the dimensions.
///
/// - Parameters:
/// - elements: Dimensions for the tensor.
public init(_ elements: Int...) {
self.init(elements)
}
}
}
extension TensorShape: ExpressibleByArrayLiteral {
/// Creates a new tensor shape instance with the given array literal representing the dimensions.
extension Tensor.Shape: ExpressibleByArrayLiteral {
/// Creates a new instance with the given array literal representing the dimensions.
///
/// - Parameters:
/// - arrayLiteral: Dimensions for the tensor.

View File

@ -12,18 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import class TensorFlowLite.Interpreter
import struct TensorFlowLite.InterpreterOptions
import struct TensorFlowLite.Tensor
import struct TensorFlowLite.TensorShape
import enum TensorFlowLite.Runtime
import TensorFlowLite
import UIKit
class ViewController: UIViewController {
// MARK: - Properties
/// TensorFlowLite interpreter object for performing inference from a given model.
/// TensorFlow Lite interpreter object for performing inference from a given model.
private var interpreter: Interpreter?
/// Serial dispatch queue for managing `Interpreter` calls.
@ -122,7 +118,7 @@ class ViewController: UIViewController {
private func setUpInterpreter(withModelPath modelPath: String) {
interpreterQueue.async {
do {
var options = InterpreterOptions()
var options = Interpreter.Options()
options.threadCount = 2
self.interpreter = try Interpreter(modelPath: modelPath, options: options)
} catch let error {
@ -211,7 +207,7 @@ class ViewController: UIViewController {
return
}
do {
let shape = TensorShape(2)
let shape = Tensor.Shape(2)
try (0..<interpreter.inputTensorCount).forEach { index in
try interpreter.resizeInput(at: index, to: shape)
}

View File

@ -1,43 +0,0 @@
// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
class InterpreterOptionsTests: XCTestCase {
func testInterpreterOptions_InitWithDefaultValues() {
let options = InterpreterOptions()
XCTAssertNil(options.threadCount)
}
func testInterpreterOptions_InitWithCustomValues() {
var options = InterpreterOptions()
options.threadCount = 2
XCTAssertEqual(options.threadCount, 2)
}
func testInterpreterOptions_Equatable() {
var options1 = InterpreterOptions()
var options2 = InterpreterOptions()
XCTAssertEqual(options1, options2)
options1.threadCount = 2
options2.threadCount = 2
XCTAssertEqual(options1, options2)
options2.threadCount = 3
XCTAssertNotEqual(options1, options2)
}
}

View File

@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
@testable import TensorFlowLite
class InterpreterTests: XCTestCase {
var interpreter: Interpreter!
@ -31,55 +32,56 @@ class InterpreterTests: XCTestCase {
super.tearDown()
}
func testInterpreter_InitWithModelPath() {
func testInitWithModelPath() {
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
}
func testInterpreter_Init_ThrowsFailedToLoadModel() {
func testInit_ThrowsFailedToLoadModel() {
XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
}
}
func testInterpreter_InitWithModelPathAndOptions() {
var options = InterpreterOptions()
func testInitWithModelPathAndOptions() throws {
var options = Interpreter.Options()
options.threadCount = 2
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path, options: options))
let interpreter = try Interpreter(modelPath: AddModel.path, options: options)
XCTAssertNotNil(interpreter.options)
}
func testInterpreter_InputTensorCount() {
func testInputTensorCount() {
XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
}
func testInterpreter_OutputTensorCount() {
func testOutputTensorCount() {
XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
}
func testInterpreter_Invoke() throws {
func testInvoke() throws {
try interpreter.allocateTensors()
XCTAssertNoThrow(try interpreter.invoke())
}
func testInterpreter_Invoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
func testInvoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
XCTAssertThrowsError(try interpreter.invoke()) { error in
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
}
}
func testInterpreter_InputTensorAtIndex() throws {
func testInputTensorAtIndex() throws {
try setUpAddModelInputTensor()
let inputTensor = try interpreter.input(at: AddModel.validIndex)
XCTAssertEqual(inputTensor, AddModel.inputTensor)
}
func testInterpreter_InputTensorAtIndex_QuantizedModel() throws {
func testInputTensorAtIndex_QuantizedModel() throws {
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
try setUpAddQuantizedModelInputTensor()
let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
}
func testInterpreter_InputTensorAtIndex_ThrowsInvalidIndex() throws {
func testInputTensorAtIndex_ThrowsInvalidIndex() throws {
try interpreter.allocateTensors()
XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
let maxIndex = AddModel.inputTensorCount - 1
@ -90,13 +92,13 @@ class InterpreterTests: XCTestCase {
}
}
func testInterpreter_InputTensorAtIndex_ThrowsAllocateTensorsRequired() {
func testInputTensorAtIndex_ThrowsAllocateTensorsRequired() {
XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
}
}
func testInterpreter_OutputTensorAtIndex() throws {
func testOutputTensorAtIndex() throws {
try setUpAddModelInputTensor()
try interpreter.invoke()
let outputTensor = try interpreter.output(at: AddModel.validIndex)
@ -105,7 +107,7 @@ class InterpreterTests: XCTestCase {
XCTAssertEqual(expectedResults, AddModel.results)
}
func testInterpreter_OutputTensorAtIndex_QuantizedModel() throws {
func testOutputTensorAtIndex_QuantizedModel() throws {
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
try setUpAddQuantizedModelInputTensor()
try interpreter.invoke()
@ -115,7 +117,7 @@ class InterpreterTests: XCTestCase {
XCTAssertEqual(expectedResults, AddQuantizedModel.results)
}
func testInterpreter_OutputTensorAtIndex_ThrowsInvalidIndex() throws {
func testOutputTensorAtIndex_ThrowsInvalidIndex() throws {
try interpreter.allocateTensors()
try interpreter.invoke()
XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in
@ -127,18 +129,18 @@ class InterpreterTests: XCTestCase {
}
}
func testInterpreter_OutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
func testOutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
}
}
func testInterpreter_ResizeInputTensorAtIndexToShape() {
func testResizeInputTensorAtIndexToShape() {
XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
XCTAssertNoThrow(try interpreter.allocateTensors())
}
func testInterpreter_ResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
func testResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
XCTAssertThrowsError(try interpreter.resizeInput(
at: AddModel.invalidIndex,
to: [2, 2, 3]
@ -151,14 +153,14 @@ class InterpreterTests: XCTestCase {
}
}
func testInterpreter_CopyDataToInputTensorAtIndex() throws {
func testCopyDataToInputTensorAtIndex() throws {
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
try interpreter.allocateTensors()
let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
XCTAssertEqual(inputTensor.data, AddModel.inputData)
}
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
func testCopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
XCTAssertThrowsError(try interpreter.copy(
AddModel.inputData,
toInputAt: AddModel.invalidIndex
@ -171,7 +173,7 @@ class InterpreterTests: XCTestCase {
}
}
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
func testCopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
try interpreter.allocateTensors()
let invalidData = Data(count: AddModel.dataCount - 1)
@ -186,7 +188,7 @@ class InterpreterTests: XCTestCase {
}
}
func testInterpreter_AllocateTensors() {
func testAllocateTensors() {
XCTAssertNoThrow(try interpreter.allocateTensors())
}
@ -215,6 +217,33 @@ class InterpreterTests: XCTestCase {
}
}
class InterpreterOptionsTests: XCTestCase {
func testInitWithDefaultValues() {
let options = Interpreter.Options()
XCTAssertNil(options.threadCount)
}
func testInitWithCustomValues() {
var options = Interpreter.Options()
options.threadCount = 2
XCTAssertEqual(options.threadCount, 2)
}
func testEquatable() {
var options1 = Interpreter.Options()
var options2 = Interpreter.Options()
XCTAssertEqual(options1, options2)
options1.threadCount = 2
options2.threadCount = 2
XCTAssertEqual(options1, options2)
options2.threadCount = 3
XCTAssertNotEqual(options1, options2)
}
}
// MARK: - Constants
/// Values for the `add.bin` model.
@ -224,7 +253,7 @@ private enum AddModel {
static let outputTensorCount = 1
static let invalidIndex = 1
static let validIndex = 0
static let shape: TensorShape = [2]
static let shape: Tensor.Shape = [2]
static let dataCount = inputData.count
static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
static let outputData = Data(copyingBufferOf: [Float32(3.0), Float32(9.0)])
@ -254,7 +283,7 @@ private enum AddModel {
private enum AddQuantizedModel {
static let info = (name: "add_quantized", extension: "bin")
static let inputOutputIndex = 0
static let shape: TensorShape = [2]
static let shape: Tensor.Shape = [2]
static let inputData = Data([1, 3])
static let outputData = Data([3, 9])
static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)

View File

@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
@testable import TensorFlowLite
class ModelTests: XCTestCase {
var modelPath: String!
@ -39,15 +40,15 @@ class ModelTests: XCTestCase {
super.tearDown()
}
func testModel_InitWithFilePath() {
func testInitWithFilePath() {
XCTAssertNotNil(Model(filePath: modelPath))
}
func testModel_InitWithEmptyFilePath_FailsInitialization() {
func testInitWithEmptyFilePath_FailsInitialization() {
XCTAssertNil(Model(filePath: ""))
}
func testModel_InitWithInvalidFilePath_FailsInitialization() {
func testInitWithInvalidFilePath_FailsInitialization() {
XCTAssertNil(Model(filePath: "invalid/path"))
}
}

View File

@ -12,18 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
@testable import TensorFlowLite
class QuantizationParametersTests: XCTestCase {
func testQuantizationParameters_InitWithCustomValues() {
func testInitWithCustomValues() {
let parameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
XCTAssertEqual(parameters.scale, 0.5)
XCTAssertEqual(parameters.zeroPoint, 1)
}
func testQuantizationParameters_Equatable() {
func testEquatable() {
let parameters1 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
let parameters2 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
XCTAssertEqual(parameters1, parameters2)

View File

@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
@testable import TensorFlowLite
class TensorFlowLiteTests: XCTestCase {
func testTensorFlowLite_Runtime_version() {
func testRuntime_Version() {
#if swift(>=5.0)
let pattern = #"^(\d+)\.(\d+)\.(\d+)([+-][-.0-9A-Za-z]+)?$"#
#else

View File

@ -12,17 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
@testable import TensorFlowLite
class TensorTests: XCTestCase {
// MARK: - Tensor
func testTensor_Init() {
func testInit() {
let name = "InputTensor"
let dataType: TensorDataType = .uInt8
let shape = TensorShape(Constant.dimensions)
let dataType: Tensor.DataType = .uInt8
let shape = Tensor.Shape(Constant.dimensions)
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
let inputTensor = Tensor(
@ -39,10 +38,10 @@ class TensorTests: XCTestCase {
XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters)
}
func testTensor_Equatable() {
func testEquatable() {
let name = "Tensor"
let dataType: TensorDataType = .uInt8
let shape = TensorShape(Constant.dimensions)
let dataType: Tensor.DataType = .uInt8
let shape = Tensor.Shape(Constant.dimensions)
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
let tensor1 = Tensor(
@ -70,30 +69,31 @@ class TensorTests: XCTestCase {
)
XCTAssertNotEqual(tensor1, tensor2)
}
}
// MARK: - TensorShape
class TensorShapeTests: XCTestCase {
func testTensorShape_InitWithArray() {
let shape = TensorShape(Constant.dimensions)
func testInitWithArray() {
let shape = Tensor.Shape(Constant.dimensions)
XCTAssertEqual(shape.rank, Constant.dimensions.count)
XCTAssertEqual(shape.dimensions, Constant.dimensions)
}
func testTensorShape_InitWithElements() {
let shape = TensorShape(2, 2, 3)
func testInitWithElements() {
let shape = Tensor.Shape(2, 2, 3)
XCTAssertEqual(shape.rank, Constant.dimensions.count)
XCTAssertEqual(shape.dimensions, Constant.dimensions)
}
func testTensorShape_InitWithArrayLiteral() {
let shape: TensorShape = [2, 2, 3]
func testInitWithArrayLiteral() {
let shape: Tensor.Shape = [2, 2, 3]
XCTAssertEqual(shape.rank, Constant.dimensions.count)
XCTAssertEqual(shape.dimensions, Constant.dimensions)
}
func testTensorShape_Equatable() {
let shape1 = TensorShape(2, 2, 3)
var shape2: TensorShape = [2, 2, 3]
func testEquatable() {
let shape1 = Tensor.Shape(2, 2, 3)
var shape2: Tensor.Shape = [2, 2, 3]
XCTAssertEqual(shape1, shape2)
shape2 = [2, 2, 4]