STT-tensorflow/tensorflow/lite/experimental/swift/Tests/InterpreterTests.swift
A. Unique TensorFlower 37c101d175 Updates the TensorFlow Lite Swift library to Swift 5.
PiperOrigin-RevId: 245493147
2019-04-26 16:31:31 -07:00

322 lines
11 KiB
Swift

// Copyright 2018 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
@testable import TensorFlowLite
import XCTest
class InterpreterTests: XCTestCase {
var interpreter: Interpreter!
override func setUp() {
super.setUp()
interpreter = try! Interpreter(modelPath: AddModel.path)
}
override func tearDown() {
interpreter = nil
super.tearDown()
}
func testInterpreter_InitWithModelPath() {
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path))
}
func testInterpreter_Init_ThrowsFailedToLoadModel() {
XCTAssertThrowsError(try Interpreter(modelPath: "/invalid/path")) { error in
self.assertEqualErrors(actual: error, expected: .failedToLoadModel)
}
}
func testInterpreter_InitWithModelPathAndOptions() {
var options = InterpreterOptions()
options.threadCount = 2
XCTAssertNoThrow(try Interpreter(modelPath: AddModel.path, options: options))
}
func testInterpreter_InputTensorCount() {
XCTAssertEqual(interpreter.inputTensorCount, AddModel.inputTensorCount)
}
func testInterpreter_OutputTensorCount() {
XCTAssertEqual(interpreter.outputTensorCount, AddModel.outputTensorCount)
}
func testInterpreter_Invoke() throws {
try interpreter.allocateTensors()
XCTAssertNoThrow(try interpreter.invoke())
}
func testInterpreter_Invoke_ThrowsAllocateTensorsRequired_ModelNotReady() {
XCTAssertThrowsError(try interpreter.invoke()) { error in
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
}
}
func testInterpreter_InputTensorAtIndex() throws {
try setUpAddModelInputTensor()
let inputTensor = try interpreter.input(at: AddModel.validIndex)
XCTAssertEqual(inputTensor, AddModel.inputTensor)
}
func testInterpreter_InputTensorAtIndex_QuantizedModel() throws {
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
try setUpAddQuantizedModelInputTensor()
let inputTensor = try interpreter.input(at: AddQuantizedModel.inputOutputIndex)
XCTAssertEqual(inputTensor, AddQuantizedModel.inputTensor)
}
func testInterpreter_InputTensorAtIndex_ThrowsInvalidIndex() throws {
try interpreter.allocateTensors()
XCTAssertThrowsError(try interpreter.input(at: AddModel.invalidIndex)) { error in
let maxIndex = AddModel.inputTensorCount - 1
self.assertEqualErrors(
actual: error,
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
)
}
}
func testInterpreter_InputTensorAtIndex_ThrowsAllocateTensorsRequired() {
XCTAssertThrowsError(try interpreter.input(at: AddModel.validIndex)) { error in
self.assertEqualErrors(actual: error, expected: .allocateTensorsRequired)
}
}
func testInterpreter_OutputTensorAtIndex() throws {
try setUpAddModelInputTensor()
try interpreter.invoke()
let outputTensor = try interpreter.output(at: AddModel.validIndex)
XCTAssertEqual(outputTensor, AddModel.outputTensor)
let expectedResults = [Float32](unsafeData: outputTensor.data)
XCTAssertEqual(expectedResults, AddModel.results)
}
func testInterpreter_OutputTensorAtIndex_QuantizedModel() throws {
interpreter = try Interpreter(modelPath: AddQuantizedModel.path)
try setUpAddQuantizedModelInputTensor()
try interpreter.invoke()
let outputTensor = try interpreter.output(at: AddQuantizedModel.inputOutputIndex)
XCTAssertEqual(outputTensor, AddQuantizedModel.outputTensor)
let expectedResults = [UInt8](outputTensor.data)
XCTAssertEqual(expectedResults, AddQuantizedModel.results)
}
func testInterpreter_OutputTensorAtIndex_ThrowsInvalidIndex() throws {
try interpreter.allocateTensors()
try interpreter.invoke()
XCTAssertThrowsError(try interpreter.output(at: AddModel.invalidIndex)) { error in
let maxIndex = AddModel.outputTensorCount - 1
self.assertEqualErrors(
actual: error,
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
)
}
}
func testInterpreter_OutputTensorAtIndex_ThrowsInvokeInterpreterRequired() {
XCTAssertThrowsError(try interpreter.output(at: AddModel.validIndex)) { error in
self.assertEqualErrors(actual: error, expected: .invokeInterpreterRequired)
}
}
func testInterpreter_ResizeInputTensorAtIndexToShape() {
XCTAssertNoThrow(try interpreter.resizeInput(at: AddModel.validIndex, to: [2, 2, 3]))
XCTAssertNoThrow(try interpreter.allocateTensors())
}
func testInterpreter_ResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
XCTAssertThrowsError(try interpreter.resizeInput(
at: AddModel.invalidIndex,
to: [2, 2, 3]
)) { error in
let maxIndex = AddModel.inputTensorCount - 1
self.assertEqualErrors(
actual: error,
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
)
}
}
func testInterpreter_CopyDataToInputTensorAtIndex() throws {
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
try interpreter.allocateTensors()
let inputTensor = try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
XCTAssertEqual(inputTensor.data, AddModel.inputData)
}
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
XCTAssertThrowsError(try interpreter.copy(
AddModel.inputData,
toInputAt: AddModel.invalidIndex
)) { error in
let maxIndex = AddModel.inputTensorCount - 1
self.assertEqualErrors(
actual: error,
expected: .invalidTensorIndex(index: AddModel.invalidIndex, maxIndex: maxIndex)
)
}
}
func testInterpreter_CopyDataToInputTensorAtIndex_ThrowsInvalidDataCount() throws {
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
try interpreter.allocateTensors()
let invalidData = Data(count: AddModel.dataCount - 1)
XCTAssertThrowsError(try interpreter.copy(
invalidData,
toInputAt: AddModel.validIndex
)) { error in
self.assertEqualErrors(
actual: error,
expected: .invalidTensorDataCount(provided: invalidData.count, required: AddModel.dataCount)
)
}
}
func testInterpreter_AllocateTensors() {
XCTAssertNoThrow(try interpreter.allocateTensors())
}
// MARK: - Private
private func setUpAddModelInputTensor() throws {
precondition(interpreter != nil)
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
try interpreter.allocateTensors()
try interpreter.copy(AddModel.inputData, toInputAt: AddModel.validIndex)
}
private func setUpAddQuantizedModelInputTensor() throws {
precondition(interpreter != nil)
try interpreter.resizeInput(at: AddQuantizedModel.inputOutputIndex, to: AddQuantizedModel.shape)
try interpreter.allocateTensors()
try interpreter.copy(AddQuantizedModel.inputData, toInputAt: AddQuantizedModel.inputOutputIndex)
}
private func assertEqualErrors(actual: Error, expected: InterpreterError) {
guard let actual = actual as? InterpreterError else {
XCTFail("Actual error should be of type InterpreterError.")
return
}
XCTAssertEqual(actual, expected)
}
}
// MARK: - Constants
/// Values for the `add.bin` model.
private enum AddModel {
static let info = (name: "add", extension: "bin")
static let inputTensorCount = 1
static let outputTensorCount = 1
static let invalidIndex = 1
static let validIndex = 0
static let shape: TensorShape = [2]
static let dataCount = inputData.count
static let inputData = Data(copyingBufferOf: [Float32(1.0), Float32(3.0)])
static let outputData = Data(copyingBufferOf: [Float32(3.0), Float32(9.0)])
static let results = [Float32(3.0), Float32(9.0)]
static let inputTensor = Tensor(
name: "input",
dataType: .float32,
shape: shape,
data: inputData
)
static let outputTensor = Tensor(
name: "output",
dataType: .float32,
shape: shape,
data: outputData
)
static var path: String = {
let bundle = Bundle(for: InterpreterTests.self)
guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" }
return path
}()
}
/// Values for the `add_quantized.bin` model.
private enum AddQuantizedModel {
static let info = (name: "add_quantized", extension: "bin")
static let inputOutputIndex = 0
static let shape: TensorShape = [2]
static let inputData = Data([1, 3])
static let outputData = Data([3, 9])
static let quantizationParameters = QuantizationParameters(scale: 0.003922, zeroPoint: 0)
static let results: [UInt8] = [3, 9]
static let inputTensor = Tensor(
name: "input",
dataType: .uInt8,
shape: shape,
data: inputData,
quantizationParameters: quantizationParameters
)
static let outputTensor = Tensor(
name: "output",
dataType: .uInt8,
shape: shape,
data: outputData,
quantizationParameters: quantizationParameters
)
static var path: String = {
let bundle = Bundle(for: InterpreterTests.self)
guard let path = bundle.path(forResource: info.name, ofType: info.extension) else { return "" }
return path
}()
}
// MARK: - Extensions
extension Array {
/// Creates a new array from the bytes of the given unsafe data.
///
/// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
/// with no indirection or reference-counting operations; otherwise, copying the raw bytes in
/// the `unsafeData`'s buffer to a new array returns an unsafe copy.
/// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
/// `MemoryLayout<Element>.stride`.
/// - Parameter unsafeData: The data containing the bytes to turn into an array.
init?(unsafeData: Data) {
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
#if swift(>=5.0)
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
#else
self = unsafeData.withUnsafeBytes {
.init(UnsafeBufferPointer<Element>(
start: $0,
count: unsafeData.count / MemoryLayout<Element>.stride
))
}
#endif // swift(>=5.0)
}
}
extension Data {
/// Creates a new buffer by copying the buffer pointer of the given array.
///
/// - Warning: The given array's element type `T` must be trivial in that it can be copied bit
/// for bit with no indirection or reference-counting operations; otherwise, reinterpreting
/// data from the resulting buffer has undefined behavior.
/// - Parameter array: An array with elements of type `T`.
init<T>(copyingBufferOf array: [T]) {
self = array.withUnsafeBufferPointer(Data.init)
}
}