From 23d482eaa2efe2bb38de7eb4f89539be9e3aa32a Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Wed, 17 Jun 2020 09:57:30 -0700 Subject: [PATCH] Add flag for using optimized TFLite CPU kernels on Android Add an experimental flag which allows opting in to a set of highly optimized floating point kernels provided via the XNNPACK delegate. This is offered as a preview, with the plan to enable these kernels by default in a future release. The flag can be enabled via: Interpreter.Options options = new Interpreter.Options().setUseXNNPACK(true); See tensorflow/lite/delegates/xnnpack/README.md for more details about these kernels and the associated delegate functionality. PiperOrigin-RevId: 316909226 Change-Id: Ib60cf259225b8a48a9830ccbb24ec10534b038ce --- tensorflow/lite/delegates/xnnpack/BUILD | 2 + .../delegates/xnnpack/xnnpack_delegate.cc | 3 + tensorflow/lite/java/BUILD | 1 + .../java/org/tensorflow/lite/Interpreter.java | 27 +++++++++ .../lite/NativeInterpreterWrapper.java | 7 +++ tensorflow/lite/java/src/main/native/BUILD | 1 + .../native/nativeinterpreterwrapper_jni.cc | 55 +++++++++++++++++++ .../lite/InterpreterMobileNetTest.java | 16 ++++++ .../org/tensorflow/lite/InterpreterTest.java | 32 +++++++++++ 9 files changed, 144 insertions(+) diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index 5736a2995b1..97e6aea2a6b 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -21,6 +21,7 @@ cc_library( linkstatic = True, deps = [ "//tensorflow/lite:kernel_api", + "//tensorflow/lite:minimal_logging", "//tensorflow/lite:util", "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs", @@ -47,6 +48,7 @@ cc_library( linkstatic = True, deps = [ "//tensorflow/lite:kernel_api", + "//tensorflow/lite:minimal_logging", "//tensorflow/lite:util", "//tensorflow/lite/c:common", "//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index c4c95b6b295..739e45f62e4 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/tools/optimize/sparsity/format_converter.h" namespace tflite { @@ -52,6 +53,8 @@ class Delegate { pthreadpool_create(static_cast(options->num_threads))); } #endif + TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO, + "Created TensorFlow Lite XNNPACK delegate for CPU."); } TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context); diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index 101e98e3dd1..d0331bca3e5 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -408,6 +408,7 @@ tflite_jni_binary( "//tensorflow/lite/c:c_api", "//tensorflow/lite/c:c_api_experimental", "//tensorflow/lite/delegates/nnapi/java/src/main/native", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "//tensorflow/lite/java/src/main/native", ], ) diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index 7c9c5644f47..5993ee7a037 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -137,10 +137,37 @@ public final class Interpreter implements AutoCloseable { return this; } + /** + * Experimental: Enable an optimized set of floating point CPU kernels (provided by XNNPACK). + * + *
<p>Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided + * via the XNNPACK delegate. Currently, this is restricted to a subset of floating point + * operations. Eventually, we plan to enable this by default, as it can provide significant + * performance benefits for many classes of floating point models. See + * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md + * for more details. + * + * <p>Things to keep in mind when enabling this flag: + * + * <ul> + *   <li>Startup time and resize time may increase. + *   <li>Baseline memory consumption may increase. + *   <li>Quantized models will not see any benefit. + * </ul> + * + *
<p>WARNING: This is an experimental interface that is subject to change. + */ + public Options setUseXNNPACK(boolean useXNNPACK) { + this.useXNNPACK = useXNNPACK; + return this; + } + int numThreads = -1; Boolean useNNAPI; Boolean allowFp16PrecisionForFp32; Boolean allowBufferHandleOutput; + Boolean useXNNPACK; final List<Delegate> delegates = new ArrayList<>(); } diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 8eb3c66f3b5..5e9a6eecf00 100644 --- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -80,6 +80,10 @@ final class NativeInterpreterWrapper implements AutoCloseable { allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue()); } applyDelegates(options); + if (options.useXNNPACK != null) { + useXNNPACK( + interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads); + } allocateTensors(interpreterHandle, errorHandle); this.isMemoryAllocated = true; } @@ -438,6 +442,9 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow); + private static native void useXNNPACK( + long interpreterHandle, long errorHandle, boolean state, int numThreads); + private static native long createErrorReporter(int size); private static native long createModel(String modelPathOrBuffer, long errorHandle); diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD index fdbbc9dc72c..52f79615a9f 100644 --- a/tensorflow/lite/java/src/main/native/BUILD +++ b/tensorflow/lite/java/src/main/native/BUILD @@ -31,6 +31,7 @@ cc_library( "//tensorflow/lite:string_util", "//tensorflow/lite:util", "//tensorflow/lite/c:common", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only", "//tensorflow/lite/experimental/tflite_api_dispatcher:tflite_api_dispatcher_with_kernels", "//tensorflow/lite/java/jni", ], ) diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 690b58ac1f4..7abe0f518f0 100644 --- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <dlfcn.h> #include <jni.h> #include <stdio.h> #include <stdlib.h> @@ -20,6 +21,7 @@ limitations under the License. #include <vector> #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h" #include "tensorflow/lite/java/src/main/native/jni_utils.h" #include "tensorflow/lite/util.h" @@ -323,6 +325,59 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput( interpreter->SetAllowBufferHandleOutput(allow); } +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK( + JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state, + jint num_threads) { + // If not using XNNPACK, simply don't apply the delegate.
+ if (!state) { + return; + } + + tflite_api_dispatcher::Interpreter* interpreter = + convertLongToInterpreter(env, handle); + if (interpreter == nullptr) { + return; + } + + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) { + return; + } + + // We use dynamic loading to avoid taking a hard dependency on XNNPACK. + // This allows clients that use trimmed builds to save on binary size. + auto xnnpack_options_default = + reinterpret_cast<decltype(TfLiteXNNPackDelegateOptionsDefault)*>( + dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateOptionsDefault")); + auto xnnpack_create = + reinterpret_cast<decltype(TfLiteXNNPackDelegateCreate)*>( + dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateCreate")); + auto xnnpack_delete = + reinterpret_cast<decltype(TfLiteXNNPackDelegateDelete)*>( + dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateDelete")); + + if (xnnpack_options_default && xnnpack_create && xnnpack_delete) { + TfLiteXNNPackDelegateOptions options = xnnpack_options_default(); + if (num_threads > 0) { + options.num_threads = num_threads; + } + tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr delegate( + xnnpack_create(&options), xnnpack_delete); + if (interpreter->ModifyGraphWithDelegate(std::move(delegate)) != + kTfLiteOk) { + ThrowException(env, kIllegalArgumentException, + "Internal error: Failed to apply XNNPACK delegate: %s", + error_reporter->CachedErrorMessage()); + } + } else { + ThrowException(env, kIllegalArgumentException, + "Failed to load XNNPACK delegate from current runtime. " + "Have you added the necessary dependencies?"); + } +} + JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env, jclass clazz, diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java index 446cf5f7b02..80b3bf3cab9 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java @@ -54,6 +54,16 @@ public final class InterpreterMobileNetTest { runMobileNetFloatTest(new Interpreter.Options().setNumThreads(2)); } + @Test + public void testMobileNetEnhancedCpuKernels() { + runMobileNetFloatTest(new Interpreter.Options().setUseXNNPACK(true)); + } + + @Test + public void testMobileNetEnhancedCpuKernelsMultithreaded() { + runMobileNetFloatTest(new Interpreter.Options().setUseXNNPACK(true).setNumThreads(2)); + } + @Test public void testMobileNetQuantized() { runMobileNetQuantizedTest(new Interpreter.Options()); @@ -64,6 +74,12 @@ public final class InterpreterMobileNetTest { runMobileNetQuantizedTest(new Interpreter.Options().setNumThreads(2)); } + @Test + public void testMobileNetQuantizedEnhancedCpu() { + // The "enhanced CPU" flag should only impact float models; this sanity test confirms that.
+ runMobileNetQuantizedTest(new Interpreter.Options().setUseXNNPACK(true)); + } + private static void runMobileNetFloatTest(Interpreter.Options options) { ByteBuffer img = TestUtils.getTestImageAsFloatByteBuffer( diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java index 3daa9fe0766..f1d4ff147b1 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -409,6 +409,38 @@ public final class InterpreterTest { interpreter.close(); } + @Test + public void testUseXNNPACK() throws Exception { + Interpreter interpreter = + new Interpreter(MODEL_BUFFER, new Interpreter.Options().setUseXNNPACK(true)); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + } + + @Test + public void testResizeWithEnhancedCpuKernels() throws Exception { + Interpreter interpreter = + new Interpreter(MODEL_BUFFER, new Interpreter.Options().setUseXNNPACK(true)); + float[] input = {1.f}; + float[] output = new float[1]; + interpreter.run(input, output); + assertThat(output).usingTolerance(0.1f).containsExactly(new float[] {3.f}).inOrder(); + + // The new input shape should trigger a resize. Inference should still work properly. + float[] input2 = {1.f, 2.f}; + float[] output2 = new float[2]; + interpreter.run(input2, output2); + assertThat(output2).usingTolerance(0.1f).containsExactly(new float[] {3.f, 6.f}).inOrder(); + } + @Test public void testRedundantClose() throws Exception { Interpreter interpreter = new Interpreter(MODEL_BUFFER);
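
Usage sketch (supplementary; not part of the applied diff): the snippet below shows how an Android client could opt in to these kernels, including a fallback for trimmed builds where the XNNPACK delegate symbols are unavailable and the JNI code above throws IllegalArgumentException. Only setUseXNNPACK(boolean), setNumThreads(int), and the existing Interpreter constructors and run() method come from this patch and the existing TFLite API; the XnnpackOptInExample class name, the loadModel() helper, the "model.tflite" path, and the tensor shapes are illustrative assumptions.

    import java.io.FileInputStream;
    import java.nio.MappedByteBuffer;
    import java.nio.channels.FileChannel;
    import org.tensorflow.lite.Interpreter;

    /** Sketch: opting in to the experimental XNNPACK CPU kernels. */
    final class XnnpackOptInExample {

      // Hypothetical helper: memory-map a .tflite model from disk.
      static MappedByteBuffer loadModel(String path) throws Exception {
        try (FileInputStream stream = new FileInputStream(path);
            FileChannel channel = stream.getChannel()) {
          return channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
        }
      }

      public static void main(String[] args) throws Exception {
        MappedByteBuffer model = loadModel("model.tflite");  // illustrative path

        // Opt in to the experimental XNNPACK-backed kernels; numThreads is also
        // forwarded to the delegate's threadpool (see useXNNPACK() above).
        Interpreter.Options options =
            new Interpreter.Options().setUseXNNPACK(true).setNumThreads(2);

        float[] input = {1.f};  // shapes are illustrative
        float[] output = new float[1];
        try (Interpreter interpreter = new Interpreter(model, options)) {
          interpreter.run(input, output);
        } catch (IllegalArgumentException e) {
          // XNNPACK delegate not linked into this runtime: fall back to the
          // default CPU kernels rather than failing outright.
          try (Interpreter interpreter = new Interpreter(model)) {
            interpreter.run(input, output);
          }
        }
      }
    }

Catching IllegalArgumentException here mirrors the two ThrowException(...) paths added in nativeinterpreterwrapper_jni.cc, so a build without the XNNPACK delegate degrades gracefully to the default CPU kernels.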