Add flag for using optimized TFLite CPU kernels on iOS
This adds new experimental flags to the interpreter options of TFLite Obj-C and Swift APIs, which can be used for opting in to a set of highly optimized floating point kernels provided via the XNNPACK delegate. The flags can be used as follows. Obj-C: TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init]; options.useXNNPACK = YES; NSError *error; TFLInterpreter *interpreter = [[TFLInterpreter alloc] initWithModelPath:@"model/path" options:options error:&error]; Swift: var options = InterpreterOptions() options.isXNNPackEnabled = true var interpreter = try Interpreter(modelPath: "model/path", options: options) PiperOrigin-RevId: 317270012 Change-Id: I82aae43c3de13ab08af3c70513e2a458e807b0f1
This commit is contained in:
parent
e51b17f458
commit
772433a2a2
@ -14,6 +14,10 @@ EMSCRIPTEN_LINKOPTS = [
|
|||||||
"-s TOTAL_MEMORY=134217728",
|
"-s TOTAL_MEMORY=134217728",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
exports_files([
|
||||||
|
"xnnpack_delegate.h",
|
||||||
|
])
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "xnnpack_delegate",
|
name = "xnnpack_delegate",
|
||||||
srcs = ["xnnpack_delegate.cc"],
|
srcs = ["xnnpack_delegate.cc"],
|
||||||
|
@ -18,10 +18,26 @@ sh_binary(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# When the static framework is built with bazel, the all header files are moved
|
||||||
|
# to the "Headers" directory with no header path prefixes. This auxiliary rule
|
||||||
|
# is used for stripping the path prefix to the "common.h" file included by the
|
||||||
|
# "xnnpack_delegate.h" header.
|
||||||
|
genrule(
|
||||||
|
name = "strip_xnnpack_include_hdr",
|
||||||
|
srcs = ["//tensorflow/lite/delegates/xnnpack:xnnpack_delegate.h"],
|
||||||
|
outs = ["xnnpack_delegate.h"],
|
||||||
|
cmd = """
|
||||||
|
sed 's|#include ".*common.h"|#include "common.h"|'\
|
||||||
|
"$(location //tensorflow/lite/delegates/xnnpack:xnnpack_delegate.h)"\
|
||||||
|
> "$@"
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
# bazel build -c opt --config=ios_fat //tensorflow/lite/experimental/ios:TensorFlowLiteC_framework
|
# bazel build -c opt --config=ios_fat //tensorflow/lite/experimental/ios:TensorFlowLiteC_framework
|
||||||
tflite_ios_static_framework(
|
tflite_ios_static_framework(
|
||||||
name = "TensorFlowLiteC_framework",
|
name = "TensorFlowLiteC_framework",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
":xnnpack_delegate.h",
|
||||||
"//tensorflow/lite/c:c_api.h",
|
"//tensorflow/lite/c:c_api.h",
|
||||||
"//tensorflow/lite/c:common.h",
|
"//tensorflow/lite/c:common.h",
|
||||||
],
|
],
|
||||||
@ -105,6 +121,7 @@ cc_library(
|
|||||||
hdrs = [
|
hdrs = [
|
||||||
"//tensorflow/lite/c:c_api.h",
|
"//tensorflow/lite/c:c_api.h",
|
||||||
"//tensorflow/lite/c:common.h",
|
"//tensorflow/lite/c:common.h",
|
||||||
|
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate.h",
|
||||||
],
|
],
|
||||||
tags = [
|
tags = [
|
||||||
"nobuilder",
|
"nobuilder",
|
||||||
@ -112,6 +129,7 @@ cc_library(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/c:c_api",
|
"//tensorflow/lite/c:c_api",
|
||||||
|
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -64,6 +64,7 @@ objc_library(
|
|||||||
visibility = ios_visibility_whitelist(),
|
visibility = ios_visibility_whitelist(),
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/lite/c:c_api",
|
"//tensorflow/lite/c:c_api",
|
||||||
|
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -26,6 +26,7 @@ Pod::Spec.new do |s|
|
|||||||
objc_dir + '{apis,sources}/*.{h,m,mm}',
|
objc_dir + '{apis,sources}/*.{h,m,mm}',
|
||||||
tfl_dir + 'c/c_api.h',
|
tfl_dir + 'c/c_api.h',
|
||||||
tfl_dir + 'c/common.h',
|
tfl_dir + 'c/common.h',
|
||||||
|
tfl_dir + 'delegates/xnnpack/xnnpack_delegate.h',
|
||||||
]
|
]
|
||||||
s.module_map = objc_dir + 'apis/framework.modulemap'
|
s.module_map = objc_dir + 'apis/framework.modulemap'
|
||||||
s.dependency 'TensorFlowLiteC', "~> #{s.version}"
|
s.dependency 'TensorFlowLiteC', "~> #{s.version}"
|
||||||
|
@ -26,6 +26,7 @@ Pod::Spec.new do |s|
|
|||||||
objc_dir + '{apis,sources}/*.{h,m,mm}',
|
objc_dir + '{apis,sources}/*.{h,m,mm}',
|
||||||
tfl_dir + 'c/c_api.h',
|
tfl_dir + 'c/c_api.h',
|
||||||
tfl_dir + 'c/common.h',
|
tfl_dir + 'c/common.h',
|
||||||
|
tfl_dir + 'delegates/xnnpack/xnnpack_delegate.h',
|
||||||
]
|
]
|
||||||
s.module_map = objc_dir + 'apis/framework.modulemap'
|
s.module_map = objc_dir + 'apis/framework.modulemap'
|
||||||
s.dependency 'TensorFlowLiteC', "#{s.version}"
|
s.dependency 'TensorFlowLiteC', "#{s.version}"
|
||||||
|
@ -26,6 +26,7 @@ Pod::Spec.new do |s|
|
|||||||
objc_dir + '{apis,sources}/*.{h,m,mm}',
|
objc_dir + '{apis,sources}/*.{h,m,mm}',
|
||||||
tfl_dir + 'c/c_api.h',
|
tfl_dir + 'c/c_api.h',
|
||||||
tfl_dir + 'c/common.h',
|
tfl_dir + 'c/common.h',
|
||||||
|
tfl_dir + 'delegates/xnnpack/xnnpack_delegate.h',
|
||||||
]
|
]
|
||||||
s.module_map = objc_dir + 'apis/framework.modulemap'
|
s.module_map = objc_dir + 'apis/framework.modulemap'
|
||||||
s.dependency 'TensorFlowLiteC', '~> 0.0.1-nightly'
|
s.dependency 'TensorFlowLiteC', '~> 0.0.1-nightly'
|
||||||
|
@ -25,6 +25,27 @@ NS_ASSUME_NONNULL_BEGIN
|
|||||||
*/
|
*/
|
||||||
@property(nonatomic) NSUInteger numberOfThreads;
|
@property(nonatomic) NSUInteger numberOfThreads;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Experimental: Enable an optimized set of floating point CPU kernels (provided by XNNPACK).
|
||||||
|
*
|
||||||
|
* Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided via the
|
||||||
|
* XNNPACK delegate. Currently, this is restricted to a subset of floating point operations.
|
||||||
|
* Eventually, we plan to enable this by default, as it can provide significant performance benefits
|
||||||
|
* for many classes of floating point models. See
|
||||||
|
* https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
|
||||||
|
* for more details.
|
||||||
|
*
|
||||||
|
* Things to keep in mind when enabling this flag:
|
||||||
|
*
|
||||||
|
* * Startup time and resize time may increase.
|
||||||
|
* * Baseline memory consumption may increase.
|
||||||
|
* * Compatibility with other delegates (e.g., GPU) has not been fully validated.
|
||||||
|
* * Quantized models will not see any benefit.
|
||||||
|
*
|
||||||
|
* WARNING: This is an experimental interface that is subject to change.
|
||||||
|
*/
|
||||||
|
@property(nonatomic) BOOL useXNNPACK;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initializes a new instance of `TFLInterpreterOptions`.
|
* Initializes a new instance of `TFLInterpreterOptions`.
|
||||||
*
|
*
|
||||||
|
@ -23,6 +23,7 @@
|
|||||||
#import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
|
#import "tensorflow/lite/experimental/objc/apis/TFLTensor.h"
|
||||||
|
|
||||||
#include "tensorflow/lite/c/c_api.h"
|
#include "tensorflow/lite/c/c_api.h"
|
||||||
|
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||||
|
|
||||||
NS_ASSUME_NONNULL_BEGIN
|
NS_ASSUME_NONNULL_BEGIN
|
||||||
|
|
||||||
@ -45,6 +46,9 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
|
|||||||
/** TfLiteInterpreter backed by C API. */
|
/** TfLiteInterpreter backed by C API. */
|
||||||
@property(nonatomic, nullable) TfLiteInterpreter *interpreter;
|
@property(nonatomic, nullable) TfLiteInterpreter *interpreter;
|
||||||
|
|
||||||
|
/** TfLiteDelegate backed by C API. */
|
||||||
|
@property(nonatomic, nullable) TfLiteDelegate *xnnpack_delegate;
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
||||||
@implementation TFLInterpreter
|
@implementation TFLInterpreter
|
||||||
@ -53,6 +57,7 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
|
|||||||
|
|
||||||
- (void)dealloc {
|
- (void)dealloc {
|
||||||
TfLiteInterpreterDelete(_interpreter);
|
TfLiteInterpreterDelete(_interpreter);
|
||||||
|
TfLiteXNNPackDelegateDelete(_xnnpack_delegate);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma mark - Public
|
#pragma mark - Public
|
||||||
@ -104,6 +109,16 @@ static void TFLInterpreterErrorReporter(void *user_data, const char *format, va_
|
|||||||
}
|
}
|
||||||
TfLiteInterpreterOptionsSetErrorReporter(cOptions, TFLInterpreterErrorReporter, nullptr);
|
TfLiteInterpreterOptionsSetErrorReporter(cOptions, TFLInterpreterErrorReporter, nullptr);
|
||||||
|
|
||||||
|
if (options.useXNNPACK) {
|
||||||
|
TfLiteXNNPackDelegateOptions xnnpack_options = TfLiteXNNPackDelegateOptionsDefault();
|
||||||
|
if (options.numberOfThreads > 0) {
|
||||||
|
xnnpack_options.num_threads = (int32_t)options.numberOfThreads;
|
||||||
|
}
|
||||||
|
|
||||||
|
_xnnpack_delegate = TfLiteXNNPackDelegateCreate(&xnnpack_options);
|
||||||
|
TfLiteInterpreterOptionsAddDelegate(cOptions, _xnnpack_delegate);
|
||||||
|
}
|
||||||
|
|
||||||
_interpreter = TfLiteInterpreterCreate(model, cOptions);
|
_interpreter = TfLiteInterpreterCreate(model, cOptions);
|
||||||
if (_interpreter == nullptr) {
|
if (_interpreter == nullptr) {
|
||||||
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
|
[TFLErrorUtil saveInterpreterErrorWithCode:TFLInterpreterErrorCodeFailedToCreateInterpreter
|
||||||
|
@ -32,6 +32,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||||||
TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
|
TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
|
||||||
XCTAssertNotNil(options);
|
XCTAssertNotNil(options);
|
||||||
XCTAssertEqual(options.numberOfThreads, 0);
|
XCTAssertEqual(options.numberOfThreads, 0);
|
||||||
|
XCTAssertFalse(options.useXNNPACK);
|
||||||
}
|
}
|
||||||
|
|
||||||
- (void)testSetNumberOfThread {
|
- (void)testSetNumberOfThread {
|
||||||
@ -44,6 +45,14 @@ NS_ASSUME_NONNULL_BEGIN
|
|||||||
XCTAssertEqual(options.numberOfThreads, 3);
|
XCTAssertEqual(options.numberOfThreads, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
- (void)testUseXNNPACK {
|
||||||
|
TFLInterpreterOptions *options = [[TFLInterpreterOptions alloc] init];
|
||||||
|
options.useXNNPACK = YES;
|
||||||
|
XCTAssertTrue(options.useXNNPACK);
|
||||||
|
options.useXNNPACK = NO;
|
||||||
|
XCTAssertFalse(options.useXNNPACK);
|
||||||
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
||||||
NS_ASSUME_NONNULL_END
|
NS_ASSUME_NONNULL_END
|
||||||
|
@ -39,6 +39,9 @@ public final class Interpreter {
|
|||||||
/// The underlying `TfLiteInterpreter` C pointer.
|
/// The underlying `TfLiteInterpreter` C pointer.
|
||||||
private var cInterpreter: CInterpreter?
|
private var cInterpreter: CInterpreter?
|
||||||
|
|
||||||
|
/// The underlying `TfLiteDelegate` C pointer for XNNPACK delegate.
|
||||||
|
private var cXNNPackDelegate: Delegate.CDelegate?
|
||||||
|
|
||||||
/// Creates a new instance with the given values.
|
/// Creates a new instance with the given values.
|
||||||
///
|
///
|
||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
@ -78,6 +81,14 @@ public final class Interpreter {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
delegates?.forEach { TfLiteInterpreterOptionsAddDelegate(cInterpreterOptions, $0.cDelegate) }
|
delegates?.forEach { TfLiteInterpreterOptionsAddDelegate(cInterpreterOptions, $0.cDelegate) }
|
||||||
|
|
||||||
|
// Configure the XNNPack delegate after the other delegates explicitly added by the user.
|
||||||
|
options.map {
|
||||||
|
if $0.isXNNPackEnabled {
|
||||||
|
configureXNNPack(options: $0, cInterpreterOptions: cInterpreterOptions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
guard let cInterpreter = TfLiteInterpreterCreate(model.cModel, cInterpreterOptions) else {
|
guard let cInterpreter = TfLiteInterpreterCreate(model.cModel, cInterpreterOptions) else {
|
||||||
throw InterpreterError.failedToCreateInterpreter
|
throw InterpreterError.failedToCreateInterpreter
|
||||||
}
|
}
|
||||||
@ -86,6 +97,7 @@ public final class Interpreter {
|
|||||||
|
|
||||||
deinit {
|
deinit {
|
||||||
TfLiteInterpreterDelete(cInterpreter)
|
TfLiteInterpreterDelete(cInterpreter)
|
||||||
|
TfLiteXNNPackDelegateDelete(cXNNPackDelegate)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the interpreter to perform inference from the loaded graph.
|
/// Invokes the interpreter to perform inference from the loaded graph.
|
||||||
@ -201,12 +213,13 @@ public final class Interpreter {
|
|||||||
guard case 0...maxIndex = index else {
|
guard case 0...maxIndex = index else {
|
||||||
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
throw InterpreterError.invalidTensorIndex(index: index, maxIndex: maxIndex)
|
||||||
}
|
}
|
||||||
guard TfLiteInterpreterResizeInputTensor(
|
guard
|
||||||
cInterpreter,
|
TfLiteInterpreterResizeInputTensor(
|
||||||
Int32(index),
|
cInterpreter,
|
||||||
shape.int32Dimensions,
|
Int32(index),
|
||||||
Int32(shape.rank)
|
shape.int32Dimensions,
|
||||||
) == kTfLiteOk
|
Int32(shape.rank)
|
||||||
|
) == kTfLiteOk
|
||||||
else {
|
else {
|
||||||
throw InterpreterError.failedToResizeInputTensor(index: index)
|
throw InterpreterError.failedToResizeInputTensor(index: index)
|
||||||
}
|
}
|
||||||
@ -236,11 +249,11 @@ public final class Interpreter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if swift(>=5.0)
|
#if swift(>=5.0)
|
||||||
let status = data.withUnsafeBytes {
|
let status = data.withUnsafeBytes {
|
||||||
TfLiteTensorCopyFromBuffer(cTensor, $0.baseAddress, data.count)
|
TfLiteTensorCopyFromBuffer(cTensor, $0.baseAddress, data.count)
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
let status = data.withUnsafeBytes { TfLiteTensorCopyFromBuffer(cTensor, $0, data.count) }
|
let status = data.withUnsafeBytes { TfLiteTensorCopyFromBuffer(cTensor, $0, data.count) }
|
||||||
#endif // swift(>=5.0)
|
#endif // swift(>=5.0)
|
||||||
guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
|
guard status == kTfLiteOk else { throw InterpreterError.failedToCopyDataToInputTensor }
|
||||||
return try input(at: index)
|
return try input(at: index)
|
||||||
@ -256,6 +269,18 @@ public final class Interpreter {
|
|||||||
throw InterpreterError.failedToAllocateTensors
|
throw InterpreterError.failedToAllocateTensors
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MARK: - Private
|
||||||
|
|
||||||
|
private func configureXNNPack(options: Options, cInterpreterOptions: OpaquePointer) {
|
||||||
|
var cXNNPackOptions = TfLiteXNNPackDelegateOptionsDefault()
|
||||||
|
if let threadCount = options.threadCount, threadCount > 0 {
|
||||||
|
cXNNPackOptions.num_threads = Int32(threadCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
cXNNPackDelegate = TfLiteXNNPackDelegateCreate(&cXNNPackOptions)
|
||||||
|
TfLiteInterpreterOptionsAddDelegate(cInterpreterOptions, cXNNPackDelegate)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
extension Interpreter {
|
extension Interpreter {
|
||||||
@ -265,6 +290,28 @@ extension Interpreter {
|
|||||||
/// indicating that the `Interpreter` will decide the number of threads to use.
|
/// indicating that the `Interpreter` will decide the number of threads to use.
|
||||||
public var threadCount: Int? = nil
|
public var threadCount: Int? = nil
|
||||||
|
|
||||||
|
/// Indicates whether an optimized set of floating point CPU kernels, provided by XNNPACK, is
|
||||||
|
/// enabled.
|
||||||
|
///
|
||||||
|
/// - Experiment:
|
||||||
|
/// Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided
|
||||||
|
/// via the XNNPACK delegate. Currently, this is restricted to a subset of floating point
|
||||||
|
/// operations. Eventually, we plan to enable this by default, as it can provide significant
|
||||||
|
/// performance benefits for many classes of floating point models. See
|
||||||
|
/// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
|
||||||
|
/// for more details.
|
||||||
|
///
|
||||||
|
/// - Important:
|
||||||
|
/// Things to keep in mind when enabling this flag:
|
||||||
|
///
|
||||||
|
/// * Startup time and resize time may increase.
|
||||||
|
/// * Baseline memory consumption may increase.
|
||||||
|
/// * Compatibility with other delegates (e.g., GPU) has not been fully validated.
|
||||||
|
/// * Quantized models will not see any benefit.
|
||||||
|
///
|
||||||
|
/// - Warning: This is an experimental interface that is subject to change.
|
||||||
|
public var isXNNPackEnabled: Bool = false
|
||||||
|
|
||||||
/// Creates a new instance with the default values.
|
/// Creates a new instance with the default values.
|
||||||
public init() {}
|
public init() {}
|
||||||
}
|
}
|
||||||
|
@ -142,10 +142,12 @@ class InterpreterTests: XCTestCase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
|
func testResizeInputTensorAtIndexToShape_ThrowsInvalidIndex() {
|
||||||
XCTAssertThrowsError(try interpreter.resizeInput(
|
XCTAssertThrowsError(
|
||||||
at: AddModel.invalidIndex,
|
try interpreter.resizeInput(
|
||||||
to: [2, 2, 3]
|
at: AddModel.invalidIndex,
|
||||||
)) { error in
|
to: [2, 2, 3]
|
||||||
|
)
|
||||||
|
) { error in
|
||||||
let maxIndex = AddModel.inputTensorCount - 1
|
let maxIndex = AddModel.inputTensorCount - 1
|
||||||
self.assertEqualErrors(
|
self.assertEqualErrors(
|
||||||
actual: error,
|
actual: error,
|
||||||
@ -162,10 +164,12 @@ class InterpreterTests: XCTestCase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func testCopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
|
func testCopyDataToInputTensorAtIndex_ThrowsInvalidIndex() {
|
||||||
XCTAssertThrowsError(try interpreter.copy(
|
XCTAssertThrowsError(
|
||||||
AddModel.inputData,
|
try interpreter.copy(
|
||||||
toInputAt: AddModel.invalidIndex
|
AddModel.inputData,
|
||||||
)) { error in
|
toInputAt: AddModel.invalidIndex
|
||||||
|
)
|
||||||
|
) { error in
|
||||||
let maxIndex = AddModel.inputTensorCount - 1
|
let maxIndex = AddModel.inputTensorCount - 1
|
||||||
self.assertEqualErrors(
|
self.assertEqualErrors(
|
||||||
actual: error,
|
actual: error,
|
||||||
@ -178,10 +182,12 @@ class InterpreterTests: XCTestCase {
|
|||||||
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
try interpreter.resizeInput(at: AddModel.validIndex, to: AddModel.shape)
|
||||||
try interpreter.allocateTensors()
|
try interpreter.allocateTensors()
|
||||||
let invalidData = Data(count: AddModel.dataCount - 1)
|
let invalidData = Data(count: AddModel.dataCount - 1)
|
||||||
XCTAssertThrowsError(try interpreter.copy(
|
XCTAssertThrowsError(
|
||||||
invalidData,
|
try interpreter.copy(
|
||||||
toInputAt: AddModel.validIndex
|
invalidData,
|
||||||
)) { error in
|
toInputAt: AddModel.validIndex
|
||||||
|
)
|
||||||
|
) { error in
|
||||||
self.assertEqualErrors(
|
self.assertEqualErrors(
|
||||||
actual: error,
|
actual: error,
|
||||||
expected: .invalidTensorDataCount(provided: invalidData.count, required: AddModel.dataCount)
|
expected: .invalidTensorDataCount(provided: invalidData.count, required: AddModel.dataCount)
|
||||||
@ -223,12 +229,20 @@ class InterpreterOptionsTests: XCTestCase {
|
|||||||
func testInitWithDefaultValues() {
|
func testInitWithDefaultValues() {
|
||||||
let options = Interpreter.Options()
|
let options = Interpreter.Options()
|
||||||
XCTAssertNil(options.threadCount)
|
XCTAssertNil(options.threadCount)
|
||||||
|
XCTAssertFalse(options.isXNNPackEnabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testInitWithCustomValues() {
|
func testInitWithCustomValues() {
|
||||||
var options = Interpreter.Options()
|
var options = Interpreter.Options()
|
||||||
|
|
||||||
options.threadCount = 2
|
options.threadCount = 2
|
||||||
XCTAssertEqual(options.threadCount, 2)
|
XCTAssertEqual(options.threadCount, 2)
|
||||||
|
|
||||||
|
options.isXNNPackEnabled = false
|
||||||
|
XCTAssertFalse(options.isXNNPackEnabled)
|
||||||
|
|
||||||
|
options.isXNNPackEnabled = true
|
||||||
|
XCTAssertTrue(options.isXNNPackEnabled)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testEquatable() {
|
func testEquatable() {
|
||||||
@ -242,6 +256,15 @@ class InterpreterOptionsTests: XCTestCase {
|
|||||||
|
|
||||||
options2.threadCount = 3
|
options2.threadCount = 3
|
||||||
XCTAssertNotEqual(options1, options2)
|
XCTAssertNotEqual(options1, options2)
|
||||||
|
|
||||||
|
options2.threadCount = 2
|
||||||
|
XCTAssertEqual(options1, options2)
|
||||||
|
|
||||||
|
options2.isXNNPackEnabled = true
|
||||||
|
XCTAssertNotEqual(options1, options2)
|
||||||
|
|
||||||
|
options1.isXNNPackEnabled = true
|
||||||
|
XCTAssertEqual(options1, options2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -326,14 +349,15 @@ extension Array {
|
|||||||
init?(unsafeData: Data) {
|
init?(unsafeData: Data) {
|
||||||
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
guard unsafeData.count % MemoryLayout<Element>.stride == 0 else { return nil }
|
||||||
#if swift(>=5.0)
|
#if swift(>=5.0)
|
||||||
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
|
self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
|
||||||
#else
|
#else
|
||||||
self = unsafeData.withUnsafeBytes {
|
self = unsafeData.withUnsafeBytes {
|
||||||
.init(UnsafeBufferPointer<Element>(
|
.init(
|
||||||
start: $0,
|
UnsafeBufferPointer<Element>(
|
||||||
count: unsafeData.count / MemoryLayout<Element>.stride
|
start: $0,
|
||||||
))
|
count: unsafeData.count / MemoryLayout<Element>.stride
|
||||||
}
|
))
|
||||||
|
}
|
||||||
#endif // swift(>=5.0)
|
#endif // swift(>=5.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user