Updates InterpreterOptions, TensorShape, and TensorDataType to be nested types.
PiperOrigin-RevId: 265160862
This commit is contained in:
parent
2fce6eaf6d
commit
c3a9eb7c28
@ -16,9 +16,9 @@ import TensorFlowLiteC
|
|||||||
|
|
||||||
/// A delegate that the `Interpreter` uses to perform TensorFlow Lite model computations.
|
/// A delegate that the `Interpreter` uses to perform TensorFlow Lite model computations.
|
||||||
public protocol Delegate: class {
|
public protocol Delegate: class {
|
||||||
/// `TFL_Delegate` C pointer type.
|
/// The `TfLiteDelegate` C pointer type.
|
||||||
typealias CDelegate = OpaquePointer
|
typealias CDelegate = UnsafeMutablePointer<TfLiteDelegate>
|
||||||
|
|
||||||
/// Delegate that performs model computations.
|
/// The delegate that performs model computations.
|
||||||
var cDelegate: CDelegate? { get }
|
var cDelegate: CDelegate { get }
|
||||||
}
|
}
|
||||||
|
@ -17,32 +17,36 @@ import TensorFlowLiteC
|
|||||||
|
|
||||||
/// A TensorFlow Lite interpreter that performs inference from a given model.
|
/// A TensorFlow Lite interpreter that performs inference from a given model.
|
||||||
public final class Interpreter {
|
public final class Interpreter {
|
||||||
/// `TfLiteInterpreter` C pointer type represented as an `UnsafePointer<TFL_Interpreter>`.
|
/// The configuration options for the `Interpreter`.
|
||||||
private typealias CInterpreter = OpaquePointer
|
public let options: Options?
|
||||||
|
|
||||||
/// Total number of input tensors associated with the model.
|
/// The total number of input tensors associated with the model.
|
||||||
public var inputTensorCount: Int {
|
public var inputTensorCount: Int {
|
||||||
return Int(TfLiteInterpreterGetInputTensorCount(cInterpreter))
|
return Int(TfLiteInterpreterGetInputTensorCount(cInterpreter))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Total number of output tensors associated with the model.
|
/// The total number of output tensors associated with the model.
|
||||||
public var outputTensorCount: Int {
|
public var outputTensorCount: Int {
|
||||||
return Int(TfLiteInterpreterGetOutputTensorCount(cInterpreter))
|
return Int(TfLiteInterpreterGetOutputTensorCount(cInterpreter))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Underlying `TfLiteInterpreter` C pointer.
|
/// The `TfLiteInterpreter` C pointer type represented as an `UnsafePointer<TfLiteInterpreter>`.
|
||||||
|
private typealias CInterpreter = OpaquePointer
|
||||||
|
|
||||||
|
/// The underlying `TfLiteInterpreter` C pointer.
|
||||||
private var cInterpreter: CInterpreter?
|
private var cInterpreter: CInterpreter?
|
||||||
|
|
||||||
/// Creates a new model interpreter instance.
|
/// Creates a new instance with the given values.
|
||||||
///
|
///
|
||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
/// - modelPath: Local file path to a TensorFlow Lite model.
|
/// - modelPath: Local file path to a TensorFlow Lite model.
|
||||||
/// - options: Custom configurations for the interpreter. Default is `nil` indicating that the
|
/// - options: Custom configuration options for the interpreter. Default is `nil` indicating
|
||||||
/// interpreter will determine the configuration options.
|
/// that the interpreter will determine the configuration options.
|
||||||
/// - Throws: An error if the model could not be loaded or the interpreter could not be created.
|
/// - Throws: An error if the model could not be loaded or the interpreter could not be created.
|
||||||
public init(modelPath: String, options: InterpreterOptions? = nil) throws {
|
public init(modelPath: String, options: Options? = nil) throws {
|
||||||
guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
|
guard let model = Model(filePath: modelPath) else { throw InterpreterError.failedToLoadModel }
|
||||||
|
|
||||||
|
self.options = options
|
||||||
let cInterpreterOptions: OpaquePointer? = try options.map { options in
|
let cInterpreterOptions: OpaquePointer? = try options.map { options in
|
||||||
guard let cOptions = TfLiteInterpreterOptionsCreate() else {
|
guard let cOptions = TfLiteInterpreterOptionsCreate() else {
|
||||||
throw InterpreterError.failedToCreateInterpreter
|
throw InterpreterError.failedToCreateInterpreter
|
||||||
@ -105,14 +109,14 @@ public final class Interpreter {
|
|||||||
else {
|
else {
|
||||||
throw InterpreterError.allocateTensorsRequired
|
throw InterpreterError.allocateTensorsRequired
|
||||||
}
|
}
|
||||||
guard let dataType = TensorDataType(type: TfLiteTensorType(cTensor)) else {
|
guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else {
|
||||||
throw InterpreterError.invalidTensorDataType
|
throw InterpreterError.invalidTensorDataType
|
||||||
}
|
}
|
||||||
|
|
||||||
let name = String(cString: nameCString)
|
let name = String(cString: nameCString)
|
||||||
let rank = TfLiteTensorNumDims(cTensor)
|
let rank = TfLiteTensorNumDims(cTensor)
|
||||||
let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
|
let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
|
||||||
let shape = TensorShape(dimensions)
|
let shape = Tensor.Shape(dimensions)
|
||||||
let byteCount = TfLiteTensorByteSize(cTensor)
|
let byteCount = TfLiteTensorByteSize(cTensor)
|
||||||
let data = Data(bytes: bytes, count: byteCount)
|
let data = Data(bytes: bytes, count: byteCount)
|
||||||
let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)
|
let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)
|
||||||
@ -151,14 +155,14 @@ public final class Interpreter {
|
|||||||
else {
|
else {
|
||||||
throw InterpreterError.invokeInterpreterRequired
|
throw InterpreterError.invokeInterpreterRequired
|
||||||
}
|
}
|
||||||
guard let dataType = TensorDataType(type: TfLiteTensorType(cTensor)) else {
|
guard let dataType = Tensor.DataType(type: TfLiteTensorType(cTensor)) else {
|
||||||
throw InterpreterError.invalidTensorDataType
|
throw InterpreterError.invalidTensorDataType
|
||||||
}
|
}
|
||||||
|
|
||||||
let name = String(cString: nameCString)
|
let name = String(cString: nameCString)
|
||||||
let rank = TfLiteTensorNumDims(cTensor)
|
let rank = TfLiteTensorNumDims(cTensor)
|
||||||
let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
|
let dimensions = (0..<rank).map { Int(TfLiteTensorDim(cTensor, $0)) }
|
||||||
let shape = TensorShape(dimensions)
|
let shape = Tensor.Shape(dimensions)
|
||||||
let byteCount = TfLiteTensorByteSize(cTensor)
|
let byteCount = TfLiteTensorByteSize(cTensor)
|
||||||
let data = Data(bytes: bytes, count: byteCount)
|
let data = Data(bytes: bytes, count: byteCount)
|
||||||
let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)
|
let cQuantizationParams = TfLiteTensorQuantizationParams(cTensor)
|
||||||
@ -187,7 +191,7 @@ public final class Interpreter {
|
|||||||
/// - index: The index for the input tensor.
|
/// - index: The index for the input tensor.
|
||||||
/// - shape: The shape that the input tensor should be resized to.
|
/// - shape: The shape that the input tensor should be resized to.
|
||||||
/// - Throws: An error if the input tensor at the given index could not be resized.
|
/// - Throws: An error if the input tensor at the given index could not be resized.
|
||||||
public func resizeInput(at index: Int, to shape: TensorShape) throws {
|
public func resizeInput(at index: Int, to shape: Tensor.Shape) throws {
|
||||||
let maxIndex = inputTensorCount - 1
|
let maxIndex = inputTensorCount - 1
|
||||||
guard case 0...maxIndex = index else {
|
guard case 0...maxIndex = index else {
|
||||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||||
@ -237,7 +241,7 @@ public final class Interpreter {
|
|||||||
return try input(at: index)
|
return try input(at: index)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allocates memory for all input tensors based on their `TensorShape`s.
|
/// Allocates memory for all input tensors based on their `Tensor.Shape`s.
|
||||||
///
|
///
|
||||||
/// - Note: This is a relatively expensive operation and should only be called after creating the
|
/// - Note: This is a relatively expensive operation and should only be called after creating the
|
||||||
/// interpreter and/or resizing any input tensors.
|
/// interpreter and/or resizing any input tensors.
|
||||||
@ -249,7 +253,17 @@ public final class Interpreter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: - Extensions
|
extension Interpreter {
|
||||||
|
/// Options for configuring the `Interpreter`.
|
||||||
|
public struct Options: Equatable, Hashable {
|
||||||
|
/// The maximum number of CPU threads that the interpreter should run on. Default is `nil`
|
||||||
|
/// indicating that the `Interpreter` will decide the number of threads to use.
|
||||||
|
public var threadCount: Int? = nil
|
||||||
|
|
||||||
|
/// Creates a new instance with the default values.
|
||||||
|
public init() {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
extension String {
|
extension String {
|
||||||
/// Returns a new `String` initialized by using the given format C array as a template into which
|
/// Returns a new `String` initialized by using the given format C array as a template into which
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
import Foundation
|
import Foundation
|
||||||
|
|
||||||
/// TensorFlow Lite interpreter errors.
|
/// Errors thrown by the TensorFlow Lite `Interpreter`.
|
||||||
public enum InterpreterError: Error, Equatable, Hashable {
|
public enum InterpreterError: Error, Equatable, Hashable {
|
||||||
case invalidTensorIndex(index: Int, maxIndex: Int)
|
case invalidTensorIndex(index: Int, maxIndex: Int)
|
||||||
case invalidTensorDataCount(provided: Int, required: Int)
|
case invalidTensorDataCount(provided: Int, required: Int)
|
||||||
@ -29,10 +29,8 @@ public enum InterpreterError: Error, Equatable, Hashable {
|
|||||||
case tensorFlowLiteError(String)
|
case tensorFlowLiteError(String)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MARK: - Extensions
|
|
||||||
|
|
||||||
extension InterpreterError: LocalizedError {
|
extension InterpreterError: LocalizedError {
|
||||||
/// Localized description of the interpreter error.
|
/// A localized description of the interpreter error.
|
||||||
public var errorDescription: String? {
|
public var errorDescription: String? {
|
||||||
switch self {
|
switch self {
|
||||||
case .invalidTensorIndex(let index, let maxIndex):
|
case .invalidTensorIndex(let index, let maxIndex):
|
||||||
@ -62,6 +60,6 @@ extension InterpreterError: LocalizedError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
extension InterpreterError: CustomStringConvertible {
|
extension InterpreterError: CustomStringConvertible {
|
||||||
/// Textual representation of the TensorFlow Lite interpreter error.
|
/// A textual representation of the TensorFlow Lite interpreter error.
|
||||||
public var description: String { return errorDescription ?? "Unknown error." }
|
public var description: String { return errorDescription ?? "Unknown error." }
|
||||||
}
|
}
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
// Copyright 2018 Google Inc. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at:
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
/// Custom configuration options for a TensorFlow Lite `Interpreter`.
|
|
||||||
public struct InterpreterOptions: Equatable {
|
|
||||||
/// Maximum number of CPU threads that the interpreter should run on. Default is `nil` indicating
|
|
||||||
/// that the `Interpreter` will decide the number of threads to use.
|
|
||||||
public var threadCount: Int? = nil
|
|
||||||
|
|
||||||
/// Creates a new instance of interpreter options.
|
|
||||||
public init() {}
|
|
||||||
}
|
|
@ -14,15 +14,15 @@
|
|||||||
|
|
||||||
import TensorFlowLiteC
|
import TensorFlowLiteC
|
||||||
|
|
||||||
/// A TensorFlow Lite model used by the 'Interpreter` to perform inference.
|
/// A TensorFlow Lite model used by the `Interpreter` to perform inference.
|
||||||
final class Model {
|
final class Model {
|
||||||
/// `TfLiteModel` C pointer type represented as an `UnsafePointer<TfLiteModel>`.
|
/// The `TfLiteModel` C pointer type represented as an `UnsafePointer<TfLiteModel>`.
|
||||||
typealias CModel = OpaquePointer
|
typealias CModel = OpaquePointer
|
||||||
|
|
||||||
/// Underlying `TfLiteModel` C pointer.
|
/// The underlying `TfLiteModel` C pointer.
|
||||||
let cModel: CModel?
|
let cModel: CModel?
|
||||||
|
|
||||||
/// Creates a new model instance.
|
/// Creates a new instance with the given `filePath`.
|
||||||
///
|
///
|
||||||
/// - Precondition: Initialization can fail if the given `filePath` is invalid.
|
/// - Precondition: Initialization can fail if the given `filePath` is invalid.
|
||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
|
@ -16,11 +16,11 @@
|
|||||||
/// be mapped to float values using the following conversion:
|
/// be mapped to float values using the following conversion:
|
||||||
/// `realValue = scale * (quantizedValue - zeroPoint)`.
|
/// `realValue = scale * (quantizedValue - zeroPoint)`.
|
||||||
public struct QuantizationParameters: Equatable, Hashable {
|
public struct QuantizationParameters: Equatable, Hashable {
|
||||||
/// Difference between real values corresponding to consecutive quantized values differing by 1.
|
/// The difference between real values corresponding to consecutive quantized values differing by
|
||||||
/// For example, the range of quantized values for `UInt8` data type is [0, 255].
|
/// 1. For example, the range of quantized values for `UInt8` data type is [0, 255].
|
||||||
public let scale: Float
|
public let scale: Float
|
||||||
|
|
||||||
/// Quantized value that corresponds to the real 0 value.
|
/// The quantized value that corresponds to the real 0 value.
|
||||||
public let zeroPoint: Int
|
public let zeroPoint: Int
|
||||||
|
|
||||||
/// Creates a new quantization parameters instance.
|
/// Creates a new quantization parameters instance.
|
||||||
|
@ -17,19 +17,19 @@ import TensorFlowLiteC
|
|||||||
|
|
||||||
/// An input or output tensor in a TensorFlow Lite graph.
|
/// An input or output tensor in a TensorFlow Lite graph.
|
||||||
public struct Tensor: Equatable, Hashable {
|
public struct Tensor: Equatable, Hashable {
|
||||||
/// Name of the tensor.
|
/// The name of the tensor.
|
||||||
public let name: String
|
public let name: String
|
||||||
|
|
||||||
/// Data type of the tensor.
|
/// The data type of the tensor.
|
||||||
public let dataType: TensorDataType
|
public let dataType: DataType
|
||||||
|
|
||||||
/// Shape of the tensor.
|
/// The shape of the tensor.
|
||||||
public let shape: TensorShape
|
public let shape: Shape
|
||||||
|
|
||||||
/// Data in the input or output tensor.
|
/// The data in the input or output tensor.
|
||||||
public let data: Data
|
public let data: Data
|
||||||
|
|
||||||
/// Quantization parameters for the tensor if using a quantized model.
|
/// The quantization parameters for the tensor if using a quantized model.
|
||||||
public let quantizationParameters: QuantizationParameters?
|
public let quantizationParameters: QuantizationParameters?
|
||||||
|
|
||||||
/// Creates a new input or output tensor instance.
|
/// Creates a new input or output tensor instance.
|
||||||
@ -43,8 +43,8 @@ public struct Tensor: Equatable, Hashable {
|
|||||||
/// Default is `nil`.
|
/// Default is `nil`.
|
||||||
init(
|
init(
|
||||||
name: String,
|
name: String,
|
||||||
dataType: TensorDataType,
|
dataType: DataType,
|
||||||
shape: TensorShape,
|
shape: Shape,
|
||||||
data: Data,
|
data: Data,
|
||||||
quantizationParameters: QuantizationParameters? = nil
|
quantizationParameters: QuantizationParameters? = nil
|
||||||
) {
|
) {
|
||||||
@ -56,83 +56,86 @@ public struct Tensor: Equatable, Hashable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Supported TensorFlow Lite tensor data types.
|
extension Tensor {
|
||||||
public enum TensorDataType: Equatable, Hashable {
|
/// The supported `Tensor` data types.
|
||||||
/// Boolean.
|
public enum DataType: Equatable, Hashable {
|
||||||
case bool
|
/// A boolean.
|
||||||
/// 8-bit unsigned integer.
|
case bool
|
||||||
case uInt8
|
/// An 8-bit unsigned integer.
|
||||||
/// 16-bit signed integer.
|
case uInt8
|
||||||
case int16
|
/// A 16-bit signed integer.
|
||||||
/// 32-bit signed integer.
|
case int16
|
||||||
case int32
|
/// A 32-bit signed integer.
|
||||||
/// 64-bit signed integer.
|
case int32
|
||||||
case int64
|
/// A 64-bit signed integer.
|
||||||
/// 16-bit half precision floating point.
|
case int64
|
||||||
case float16
|
/// A 16-bit half precision floating point.
|
||||||
/// 32-bit single precision floating point.
|
case float16
|
||||||
case float32
|
/// A 32-bit single precision floating point.
|
||||||
|
case float32
|
||||||
|
|
||||||
/// Creates a new tensor data type from the given `TfLiteType` or `nil` if the data type is
|
/// Creates a new instance from the given `TfLiteType` or `nil` if the data type is unsupported
|
||||||
/// unsupported or could not be determined because there was an error.
|
/// or could not be determined because there was an error.
|
||||||
///
|
///
|
||||||
/// - Parameter type: A data type supported by a tensor.
|
/// - Parameter type: Data type supported by a tensor.
|
||||||
init?(type: TfLiteType) {
|
init?(type: TfLiteType) {
|
||||||
switch type {
|
switch type {
|
||||||
case kTfLiteBool:
|
case kTfLiteBool:
|
||||||
self = .bool
|
self = .bool
|
||||||
case kTfLiteUInt8:
|
case kTfLiteUInt8:
|
||||||
self = .uInt8
|
self = .uInt8
|
||||||
case kTfLiteInt16:
|
case kTfLiteInt16:
|
||||||
self = .int16
|
self = .int16
|
||||||
case kTfLiteInt32:
|
case kTfLiteInt32:
|
||||||
self = .int32
|
self = .int32
|
||||||
case kTfLiteInt64:
|
case kTfLiteInt64:
|
||||||
self = .int64
|
self = .int64
|
||||||
case kTfLiteFloat16:
|
case kTfLiteFloat16:
|
||||||
self = .float16
|
self = .float16
|
||||||
case kTfLiteFloat32:
|
case kTfLiteFloat32:
|
||||||
self = .float32
|
self = .float32
|
||||||
case kTfLiteNoType:
|
case kTfLiteNoType:
|
||||||
fallthrough
|
fallthrough
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The shape of a TensorFlow Lite tensor.
|
extension Tensor {
|
||||||
public struct TensorShape: Equatable, Hashable {
|
/// The shape of a `Tensor`.
|
||||||
|
public struct Shape: Equatable, Hashable {
|
||||||
|
/// The number of dimensions of the tensor.
|
||||||
|
public let rank: Int
|
||||||
|
|
||||||
/// The number of dimensions of the tensor.
|
/// An array of dimensions for the tensor.
|
||||||
public let rank: Int
|
public let dimensions: [Int]
|
||||||
|
|
||||||
/// Array of dimensions for the tensor.
|
/// An array of `Int32` dimensions for the tensor.
|
||||||
public let dimensions: [Int]
|
var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
|
||||||
|
|
||||||
/// Array of `Int32` dimensions for the tensor.
|
/// Creates a new instance with the given array of dimensions.
|
||||||
var int32Dimensions: [Int32] { return dimensions.map(Int32.init) }
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - dimensions: Dimensions for the tensor.
|
||||||
|
public init(_ dimensions: [Int]) {
|
||||||
|
self.rank = dimensions.count
|
||||||
|
self.dimensions = dimensions
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates a new tensor shape instance with the given array of dimensions.
|
/// Creates a new instance with the given elements representing the dimensions.
|
||||||
///
|
///
|
||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
/// - dimensions: Dimensions for the tensor.
|
/// - elements: Dimensions for the tensor.
|
||||||
public init(_ dimensions: [Int]) {
|
public init(_ elements: Int...) {
|
||||||
self.rank = dimensions.count
|
self.init(elements)
|
||||||
self.dimensions = dimensions
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a new tensor shape instance with the given elements representing the dimensions.
|
|
||||||
///
|
|
||||||
/// - Parameters:
|
|
||||||
/// - elements: Dimensions for the tensor.
|
|
||||||
public init(_ elements: Int...) {
|
|
||||||
self.init(elements)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extension TensorShape: ExpressibleByArrayLiteral {
|
extension Tensor.Shape: ExpressibleByArrayLiteral {
|
||||||
/// Creates a new tensor shape instance with the given array literal representing the dimensions.
|
/// Creates a new instance with the given array literal representing the dimensions.
|
||||||
///
|
///
|
||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
/// - arrayLiteral: Dimensions for the tensor.
|
/// - arrayLiteral: Dimensions for the tensor.
|
||||||
|
@ -12,18 +12,14 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import class TensorFlowLite.Interpreter
|
import TensorFlowLite
|
||||||
import struct TensorFlowLite.InterpreterOptions
|
|
||||||
import struct TensorFlowLite.Tensor
|
|
||||||
import struct TensorFlowLite.TensorShape
|
|
||||||
import enum TensorFlowLite.Runtime
|
|
||||||
import UIKit
|
import UIKit
|
||||||
|
|
||||||
class ViewController: UIViewController {
|
class ViewController: UIViewController {
|
||||||
|
|
||||||
// MARK: - Properties
|
// MARK: - Properties
|
||||||
|
|
||||||
/// TensorFlowLite interpreter object for performing inference from a given model.
|
/// TensorFlow Lite interpreter object for performing inference from a given model.
|
||||||
private var interpreter: Interpreter?
|
private var interpreter: Interpreter?
|
||||||
|
|
||||||
/// Serial dispatch queue for managing `Interpreter` calls.
|
/// Serial dispatch queue for managing `Interpreter` calls.
|
||||||
@ -122,7 +118,7 @@ class ViewController: UIViewController {
|
|||||||
private func setUpInterpreter(withModelPath modelPath: String) {
|
private func setUpInterpreter(withModelPath modelPath: String) {
|
||||||
interpreterQueue.async {
|
interpreterQueue.async {
|
||||||
do {
|
do {
|
||||||
var options = InterpreterOptions()
|
var options = Interpreter.Options()
|
||||||
options.threadCount = 2
|
options.threadCount = 2
|
||||||
self.interpreter = try Interpreter(modelPath: modelPath, options: options)
|
self.interpreter = try Interpreter(modelPath: modelPath, options: options)
|
||||||
} catch let error {
|
} catch let error {
|
||||||
@ -211,7 +207,7 @@ class ViewController: UIViewController {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
do {
|
do {
|
||||||
let shape = TensorShape(2)
|
let shape = Tensor.Shape(2)
|
||||||
try (0..<interpreter.inputTensorCount).forEach { index in
|
try (0..<interpreter.inputTensorCount).forEach { index in
|
||||||
try interpreter.resizeInput(at: index, to: shape)
|
try interpreter.resizeInput(at: index, to: shape)
|
||||||
}
|
}
|
||||||
|
@ -1,43 +0,0 @@
|
|||||||
// Copyright 2018 Google Inc. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at:
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
@testable import TensorFlowLite
|
|
||||||
import XCTest
|
|
||||||
|
|
||||||
class InterpreterOptionsTests: XCTestCase {
|
|
||||||
|
|
||||||
func testInterpreterOptions_InitWithDefaultValues() {
|
|
||||||
let options = InterpreterOptions()
|
|
||||||
XCTAssertNil(options.threadCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testInterpreterOptions_InitWithCustomValues() {
|
|
||||||
var options = InterpreterOptions()
|
|
||||||
options.threadCount = 2
|
|
||||||
XCTAssertEqual(options.threadCount, 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func testInterpreterOptions_Equatable() {
|
|
||||||
var options1 = InterpreterOptions()
|
|
||||||
var options2 = InterpreterOptions()
|
|
||||||
XCTAssertEqual(options1, options2)
|
|
||||||
|
|
||||||
options1.threadCount = 2
|
|
||||||
options2.threadCount = 2
|
|
||||||
XCTAssertEqual(options1, options2)
|
|
||||||
|
|
||||||
options2.threadCount = 3
|
|
||||||
XCTAssertNotEqual(options1, options2)
|
|
||||||
}
|
|
||||||
}
|
|
@ -12,9 +12,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
@testable import TensorFlowLite
|
|
||||||
import XCTest
|
import XCTest
|
||||||
|
|
||||||
|
@testable import TensorFlowLite
|
||||||
|
|
||||||
class InterpreterTests: XCTestCase {
|
class InterpreterTests: XCTestCase {
|
||||||
|
|
||||||
var interpreter: Interpreter!
|
var interpreter: Interpreter!
|
||||||
@ -31,55 +32,56 @@ class InterpreterTests: XCTestCase {
|
|||||||
super.tearDown()
|
super.tearDown()
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_InitWithModelPath() {
|
func testInitWithModelPath() {
|
||||||
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
|
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_Init_ThrowsFailedToLoadModel() {
|
func testInit_ThrowsFailedToLoadModel() {
|
||||||
XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
|
XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
|
||||||
self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
|
self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_InitWithModelPathAndOptions() {
|
func testInitWithModelPathAndOptions() throws {
|
||||||
var options = InterpreterOptions()
|
var options = Interpreter.Options()
|
||||||
options.threadCount = 2
|
options.threadCount = 2
|
||||||
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path, options: options))
|
let interpreter = try Interpreter(modelPath: AddModel.path, options: options)
|
||||||
|
XCTAssertNotNil(interpreter.options)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_InputTensorCount() {
|
func testInputTensorCount() {
|
||||||
XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
|
XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_OutputTensorCount() {
|
func testOutputTensorCount() {
|
||||||
XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
|
XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_Invoke() throws {
|
func testInvoke() throws {
|
||||||
try interpreter.allocateTensors()
|
try interpreter.allocateTensors()
|
||||||
XCTAssertNoThrow(try interpreter.invoke())
|
XCTAssertNoThrow(try interpreter.invoke())
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_Invoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
|
func testInvoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
|
||||||
XCTAssertThrowsError(try interpreter.invoke()) { error in
|
XCTAssertThrowsError(try interpreter.invoke()) { error in
|
||||||
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
|
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_InputTensorAtIndex() throws {
|
func testInputTensorAtIndex() throws {
|
||||||
try setUpAddModelInputTensor()
|
try setUpAddModelInputTensor()
|
||||||
let inputTensor = try interpreter.input(at: AddModel.validIndex)
|
let inputTensor = try interpreter.input(at: AddModel.validIndex)
|
||||||
XCTAssertEqual(inputTensor, AddModel.inputTensor)
|
XCTAssertEqual(inputTensor, AddModel.inputTensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_InputTensorAtIndex_QuantizedModel() throws {
|
func testInputTensorAtIndex_QuantizedModel() throws {
|
||||||
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
|
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
|
||||||
try setUpAddQuantizedModelInputTensor()
|
try setUpAddQuantizedModelInputTensor()
|
||||||
let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
|
let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
|
||||||
XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
|
XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_InputTensorAtIndex_ThrowsInvalidIndex() throws {
|
func testInputTensorAtIndex_ThrowsInvalidIndex() throws {
|
||||||
try interpreter.allocateTensors()
|
try interpreter.allocateTensors()
|
||||||
XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
|
XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
|
||||||
let maxIndex = AddModel.inputTensorCount - 1
|
let maxIndex = AddModel.inputTensorCount - 1
|
||||||
@ -90,13 +92,13 @@ class InterpreterTests: XCTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_InputTensorAtIndex_ThrowsAllocateTensorsRequired() {
|
func testInputTensorAtIndex_ThrowsAllocateTensorsRequired() {
|
||||||
XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
|
XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
|
||||||
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
|
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_OutputTensorAtIndex() throws {
|
func testOutputTensorAtIndex() throws {
|
||||||
try setUpAddModelInputTensor()
|
try setUpAddModelInputTensor()
|
||||||
try interpreter.invoke()
|
try interpreter.invoke()
|
||||||
let outputTensor = try interpreter.output(at: AddModel.validIndex)
|
let outputTensor = try interpreter.output(at: AddModel.validIndex)
|
||||||
@ -105,7 +107,7 @@ class InterpreterTests: XCTestCase {
|
|||||||
XCTAssertEqual(expectedResults, AddModel.results)
|
XCTAssertEqual(expectedResults, AddModel.results)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_OutputTensorAtIndex_QuantizedModel() throws {
|
func testOutputTensorAtIndex_QuantizedModel() throws {
|
||||||
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
|
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
|
||||||
try setUpAddQuantizedModelInputTensor()
|
try setUpAddQuantizedModelInputTensor()
|
||||||
try interpreter.invoke()
|
try interpreter.invoke()
|
||||||
@ -115,7 +117,7 @@ class InterpreterTests: XCTestCase {
|
|||||||
XCTAssertEqual(expectedResults, AddQuantizedModel.results)
|
XCTAssertEqual(expectedResults, AddQuantizedModel.results)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_OutputTensorAtIndex_ThrowsInvalidIndex() throws {
|
func testOutputTensorAtIndex_ThrowsInvalidIndex() throws {
|
||||||
try interpreter.allocateTensors()
|
try interpreter.allocateTensors()
|
||||||
try interpreter.invoke()
|
try interpreter.invoke()
|
||||||
XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in
|
XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in
|
||||||
@ -127,18 +129,18 @@ class InterpreterTests: XCTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_OutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
|
func testOutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
|
||||||
XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
|
XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
|
||||||
self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
|
self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_ResizeInputTensorAtIndexToShape() {
|
func testResizeInputTensorAtIndexToShape() {
|
||||||
XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
|
XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
|
||||||
XCTAssertNoThrow(try interpreter.allocateTensors())
|
XCTAssertNoThrow(try interpreter.allocateTensors())
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_ResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
|
func testResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
|
||||||
XCTAssertThrowsError(try interpreter.resizeInput(
|
XCTAssertThrowsError(try interpreter.resizeInput(
|
||||||
at: AddModel.invalidIndex,
|
at: AddModel.invalidIndex,
|
||||||
to: [2, 2, 3]
|
to: [2, 2, 3]
|
||||||
@ -151,14 +153,14 @@ class InterpreterTests: XCTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_CopyDataToInputTensorAtIndex() throws {
|
func testCopyDataToInputTensorAtIndex() throws {
|
||||||
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||||
try interpreter.allocateTensors()
|
try interpreter.allocateTensors()
|
||||||
let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
|
let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
|
||||||
XCTAssertEqual(inputTensor.data, AddModel.inputData)
|
XCTAssertEqual(inputTensor.data, AddModel.inputData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
|
func testCopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
|
||||||
XCTAssertThrowsError(try interpreter.copy(
|
XCTAssertThrowsError(try interpreter.copy(
|
||||||
AddModel.inputData,
|
AddModel.inputData,
|
||||||
toInputAt: AddModel.invalidIndex
|
toInputAt: AddModel.invalidIndex
|
||||||
@ -171,7 +173,7 @@ class InterpreterTests: XCTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
|
func testCopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
|
||||||
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||||
try interpreter.allocateTensors()
|
try interpreter.allocateTensors()
|
||||||
let invalidData = Data(count: AddModel.dataCount - 1)
|
let invalidData = Data(count: AddModel.dataCount - 1)
|
||||||
@ -186,7 +188,7 @@ class InterpreterTests: XCTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInterpreter_AllocateTensors() {
|
func testAllocateTensors() {
|
||||||
XCTAssertNoThrow(try interpreter.allocateTensors())
|
XCTAssertNoThrow(try interpreter.allocateTensors())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,6 +217,33 @@ class InterpreterTests: XCTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class InterpreterOptionsTests: XCTestCase {
|
||||||
|
|
||||||
|
func testInitWithDefaultValues() {
|
||||||
|
let options = Interpreter.Options()
|
||||||
|
XCTAssertNil(options.threadCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testInitWithCustomValues() {
|
||||||
|
var options = Interpreter.Options()
|
||||||
|
options.threadCount = 2
|
||||||
|
XCTAssertEqual(options.threadCount, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testEquatable() {
|
||||||
|
var options1 = Interpreter.Options()
|
||||||
|
var options2 = Interpreter.Options()
|
||||||
|
XCTAssertEqual(options1, options2)
|
||||||
|
|
||||||
|
options1.threadCount = 2
|
||||||
|
options2.threadCount = 2
|
||||||
|
XCTAssertEqual(options1, options2)
|
||||||
|
|
||||||
|
options2.threadCount = 3
|
||||||
|
XCTAssertNotEqual(options1, options2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MARK: - Constants
|
// MARK: - Constants
|
||||||
|
|
||||||
/// Values for the `add.bin` model.
|
/// Values for the `add.bin` model.
|
||||||
@ -224,7 +253,7 @@ private enum AddModel {
|
|||||||
static let outputTensorCount = 1
|
static let outputTensorCount = 1
|
||||||
static let invalidIndex = 1
|
static let invalidIndex = 1
|
||||||
static let validIndex = 0
|
static let validIndex = 0
|
||||||
static let shape: TensorShape = [2]
|
static let shape: Tensor.Shape = [2]
|
||||||
static let dataCount = inputData.count
|
static let dataCount = inputData.count
|
||||||
static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
|
static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
|
||||||
static let outputData = Data(copyingBufferOf: [Float32(3.0), Float32(9.0)])
|
static let outputData = Data(copyingBufferOf: [Float32(3.0), Float32(9.0)])
|
||||||
@ -254,7 +283,7 @@ private enum AddModel {
|
|||||||
private enum AddQuantizedModel {
|
private enum AddQuantizedModel {
|
||||||
static let info = (name: "add_quantized", extension: "bin")
|
static let info = (name: "add_quantized", extension: "bin")
|
||||||
static let inputOutputIndex = 0
|
static let inputOutputIndex = 0
|
||||||
static let shape: TensorShape = [2]
|
static let shape: Tensor.Shape = [2]
|
||||||
static let inputData = Data([1, 3])
|
static let inputData = Data([1, 3])
|
||||||
static let outputData = Data([3, 9])
|
static let outputData = Data([3, 9])
|
||||||
static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)
|
static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)
|
||||||
|
@ -12,9 +12,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
@testable import TensorFlowLite
|
|
||||||
import XCTest
|
import XCTest
|
||||||
|
|
||||||
|
@testable import TensorFlowLite
|
||||||
|
|
||||||
class ModelTests: XCTestCase {
|
class ModelTests: XCTestCase {
|
||||||
|
|
||||||
var modelPath: String!
|
var modelPath: String!
|
||||||
@ -39,15 +40,15 @@ class ModelTests: XCTestCase {
|
|||||||
super.tearDown()
|
super.tearDown()
|
||||||
}
|
}
|
||||||
|
|
||||||
func testModel_InitWithFilePath() {
|
func testInitWithFilePath() {
|
||||||
XCTAssertNotNil(Model(filePath: modelPath))
|
XCTAssertNotNil(Model(filePath: modelPath))
|
||||||
}
|
}
|
||||||
|
|
||||||
func testModel_InitWithEmptyFilePath_FailsInitialization() {
|
func testInitWithEmptyFilePath_FailsInitialization() {
|
||||||
XCTAssertNil(Model(filePath: ""))
|
XCTAssertNil(Model(filePath: ""))
|
||||||
}
|
}
|
||||||
|
|
||||||
func testModel_InitWithInvalidFilePath_FailsInitialization() {
|
func testInitWithInvalidFilePath_FailsInitialization() {
|
||||||
XCTAssertNil(Model(filePath: "invalid/path"))
|
XCTAssertNil(Model(filePath: "invalid/path"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,18 +12,19 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
@testable import TensorFlowLite
|
|
||||||
import XCTest
|
import XCTest
|
||||||
|
|
||||||
|
@testable import TensorFlowLite
|
||||||
|
|
||||||
class QuantizationParametersTests: XCTestCase {
|
class QuantizationParametersTests: XCTestCase {
|
||||||
|
|
||||||
func testQuantizationParameters_InitWithCustomValues() {
|
func testInitWithCustomValues() {
|
||||||
let parameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
let parameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||||
XCTAssertEqual(parameters.scale, 0.5)
|
XCTAssertEqual(parameters.scale, 0.5)
|
||||||
XCTAssertEqual(parameters.zeroPoint, 1)
|
XCTAssertEqual(parameters.zeroPoint, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testQuantizationParameters_Equatable() {
|
func testEquatable() {
|
||||||
let parameters1 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
let parameters1 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||||
let parameters2 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
let parameters2 = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||||
XCTAssertEqual(parameters1, parameters2)
|
XCTAssertEqual(parameters1, parameters2)
|
||||||
|
@ -12,12 +12,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
@testable import TensorFlowLite
|
|
||||||
import XCTest
|
import XCTest
|
||||||
|
|
||||||
|
@testable import TensorFlowLite
|
||||||
|
|
||||||
class TensorFlowLiteTests: XCTestCase {
|
class TensorFlowLiteTests: XCTestCase {
|
||||||
|
|
||||||
func testTensorFlowLite_Runtime_version() {
|
func testRuntime_Version() {
|
||||||
#if swift(>=5.0)
|
#if swift(>=5.0)
|
||||||
let pattern = #"^(\d+)\.(\d+)\.(\d+)([+-][-.0-9A-Za-z]+)?$"#
|
let pattern = #"^(\d+)\.(\d+)\.(\d+)([+-][-.0-9A-Za-z]+)?$"#
|
||||||
#else
|
#else
|
||||||
|
@ -12,17 +12,16 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
@testable import TensorFlowLite
|
|
||||||
import XCTest
|
import XCTest
|
||||||
|
|
||||||
|
@testable import TensorFlowLite
|
||||||
|
|
||||||
class TensorTests: XCTestCase {
|
class TensorTests: XCTestCase {
|
||||||
|
|
||||||
// MARK: - Tensor
|
func testInit() {
|
||||||
|
|
||||||
func testTensor_Init() {
|
|
||||||
let name = "InputTensor"
|
let name = "InputTensor"
|
||||||
let dataType: TensorDataType = .uInt8
|
let dataType: Tensor.DataType = .uInt8
|
||||||
let shape = TensorShape(Constant.dimensions)
|
let shape = Tensor.Shape(Constant.dimensions)
|
||||||
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
|
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
|
||||||
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||||
let inputTensor = Tensor(
|
let inputTensor = Tensor(
|
||||||
@ -39,10 +38,10 @@ class TensorTests: XCTestCase {
|
|||||||
XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters)
|
XCTAssertEqual(inputTensor.quantizationParameters, quantizationParameters)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testTensor_Equatable() {
|
func testEquatable() {
|
||||||
let name = "Tensor"
|
let name = "Tensor"
|
||||||
let dataType: TensorDataType = .uInt8
|
let dataType: Tensor.DataType = .uInt8
|
||||||
let shape = TensorShape(Constant.dimensions)
|
let shape = Tensor.Shape(Constant.dimensions)
|
||||||
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
|
guard let data = name.data(using: .utf8) else { XCTFail("Data should not be nil."); return }
|
||||||
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
let quantizationParameters = QuantizationParameters(scale: 0.5, zeroPoint: 1)
|
||||||
let tensor1 = Tensor(
|
let tensor1 = Tensor(
|
||||||
@ -70,30 +69,31 @@ class TensorTests: XCTestCase {
|
|||||||
)
|
)
|
||||||
XCTAssertNotEqual(tensor1, tensor2)
|
XCTAssertNotEqual(tensor1, tensor2)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// MARK: - TensorShape
|
class TensorShapeTests: XCTestCase {
|
||||||
|
|
||||||
func testTensorShape_InitWithArray() {
|
func testInitWithArray() {
|
||||||
let shape = TensorShape(Constant.dimensions)
|
let shape = Tensor.Shape(Constant.dimensions)
|
||||||
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||||
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testTensorShape_InitWithElements() {
|
func testInitWithElements() {
|
||||||
let shape = TensorShape(2, 2, 3)
|
let shape = Tensor.Shape(2, 2, 3)
|
||||||
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||||
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testTensorShape_InitWithArrayLiteral() {
|
func testInitWithArrayLiteral() {
|
||||||
let shape: TensorShape = [2, 2, 3]
|
let shape: Tensor.Shape = [2, 2, 3]
|
||||||
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
XCTAssertEqual(shape.rank, Constant.dimensions.count)
|
||||||
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
XCTAssertEqual(shape.dimensions, Constant.dimensions)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testTensorShape_Equatable() {
|
func testEquatable() {
|
||||||
let shape1 = TensorShape(2, 2, 3)
|
let shape1 = Tensor.Shape(2, 2, 3)
|
||||||
var shape2: TensorShape = [2, 2, 3]
|
var shape2: Tensor.Shape = [2, 2, 3]
|
||||||
XCTAssertEqual(shape1, shape2)
|
XCTAssertEqual(shape1, shape2)
|
||||||
|
|
||||||
shape2 = [2, 2, 4]
|
shape2 = [2, 2, 4]
|
||||||
|
Loading…
Reference in New Issue
Block a user