Add flag for using optimized TFLite CPU kernels on Android

Add an experimental flag which allows opting in to a set of highly
optimized floating point kernels provided via the XNNPACK delegate.
This is offered as a preview, with the plan to enable these kernels
by default in a future release. The flag can be enabled via:

  Interpreter.Options options =
      new Interpreter.Options().setUseXNNPACK(true);

See tensorflow/lite/delegates/xnnpack/README.md for more details about
these kernels and the associated delegate functionality.

PiperOrigin-RevId: 316909226
Change-Id: Ib60cf259225b8a48a9830ccbb24ec10534b038ce
This commit is contained in:
Jared Duke 2020-06-17 09:57:30 -07:00 committed by TensorFlower Gardener
parent 2779d9e29d
commit 23d482eaa2
9 changed files with 144 additions and 0 deletions

View File

@ -21,6 +21,7 @@ cc_library(
linkstatic = True, linkstatic = True,
deps = [ deps = [
"//tensorflow/lite:kernel_api", "//tensorflow/lite:kernel_api",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite:util", "//tensorflow/lite:util",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
@ -47,6 +48,7 @@ cc_library(
linkstatic = True, linkstatic = True,
deps = [ deps = [
"//tensorflow/lite:kernel_api", "//tensorflow/lite:kernel_api",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite:util", "//tensorflow/lite:util",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h" #include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
namespace tflite { namespace tflite {
@ -52,6 +53,8 @@ class Delegate {
pthreadpool_create(static_cast<size_t>(options->num_threads))); pthreadpool_create(static_cast<size_t>(options->num_threads)));
} }
#endif #endif
TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
"Created TensorFlow Lite XNNPACK delegate for CPU.");
} }
TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context); TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context);

View File

@ -408,6 +408,7 @@ tflite_jni_binary(
"//tensorflow/lite/c:c_api", "//tensorflow/lite/c:c_api",
"//tensorflow/lite/c:c_api_experimental", "//tensorflow/lite/c:c_api_experimental",
"//tensorflow/lite/delegates/nnapi/java/src/main/native", "//tensorflow/lite/delegates/nnapi/java/src/main/native",
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"//tensorflow/lite/java/src/main/native", "//tensorflow/lite/java/src/main/native",
], ],
) )

View File

@ -137,10 +137,37 @@ public final class Interpreter implements AutoCloseable {
return this; return this;
} }
/**
* Experimental: Enable an optimized set of floating point CPU kernels (provided by XNNPACK).
*
* <p>Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided
* via the XNNPACK delegate. Currently, this is restricted to a subset of floating point
* operations. Eventually, we plan to enable this by default, as it can provide significant
* peformance benefits for many classes of floating point models. See
* https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
* for more details.
*
* <p>Things to keep in mind when enabling this flag:
*
* <ul>
* <li>Startup time and resize time may increase.
* <li>Baseline memory consumption may increase.
* <li>Compatibility with other delegates (e.g., GPU) has not been fully validated.
* <li>Quantized models will not see any benefit.
* </ul>
*
* <p>WARNING: This is an experimental interface that is subject to change.
*/
public Options setUseXNNPACK(boolean useXNNPACK) {
this.useXNNPACK = useXNNPACK;
return this;
}
int numThreads = -1; int numThreads = -1;
Boolean useNNAPI; Boolean useNNAPI;
Boolean allowFp16PrecisionForFp32; Boolean allowFp16PrecisionForFp32;
Boolean allowBufferHandleOutput; Boolean allowBufferHandleOutput;
Boolean useXNNPACK;
final List<Delegate> delegates = new ArrayList<>(); final List<Delegate> delegates = new ArrayList<>();
} }

View File

@ -80,6 +80,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue()); allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
} }
applyDelegates(options); applyDelegates(options);
if (options.useXNNPACK != null) {
useXNNPACK(
interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads);
}
allocateTensors(interpreterHandle, errorHandle); allocateTensors(interpreterHandle, errorHandle);
this.isMemoryAllocated = true; this.isMemoryAllocated = true;
} }
@ -438,6 +442,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow); private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);
private static native void useXNNPACK(
long interpreterHandle, long errorHandle, boolean state, int numThreads);
private static native long createErrorReporter(int size); private static native long createErrorReporter(int size);
private static native long createModel(String modelPathOrBuffer, long errorHandle); private static native long createModel(String modelPathOrBuffer, long errorHandle);

View File

@ -31,6 +31,7 @@ cc_library(
"//tensorflow/lite:string_util", "//tensorflow/lite:string_util",
"//tensorflow/lite:util", "//tensorflow/lite:util",
"//tensorflow/lite/c:common", "//tensorflow/lite/c:common",
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
"//tensorflow/lite/experimental/tflite_api_dispatcher:tflite_api_dispatcher_with_kernels", "//tensorflow/lite/experimental/tflite_api_dispatcher:tflite_api_dispatcher_with_kernels",
"//tensorflow/lite/java/jni", "//tensorflow/lite/java/jni",
], ],

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include <dlfcn.h>
#include <jni.h> #include <jni.h>
#include <stdio.h> #include <stdio.h>
#include <time.h> #include <time.h>
@ -20,6 +21,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
#include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h" #include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h"
#include "tensorflow/lite/java/src/main/native/jni_utils.h" #include "tensorflow/lite/java/src/main/native/jni_utils.h"
#include "tensorflow/lite/util.h" #include "tensorflow/lite/util.h"
@ -323,6 +325,59 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
interpreter->SetAllowBufferHandleOutput(allow); interpreter->SetAllowBufferHandleOutput(allow);
} }
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state,
jint num_threads) {
// If not using xnnpack, simply don't apply the delegate.
if (!state) {
return;
}
tflite_api_dispatcher::Interpreter* interpreter =
convertLongToInterpreter(env, handle);
if (interpreter == nullptr) {
return;
}
BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
if (error_reporter == nullptr) {
return;
}
// We use dynamic loading to avoid taking a hard dependency on XNNPack.
// This allows clients that use trimmed builds to save on binary size.
auto xnnpack_options_default =
reinterpret_cast<decltype(TfLiteXNNPackDelegateOptionsDefault)*>(
dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateOptionsDefault"));
auto xnnpack_create =
reinterpret_cast<decltype(TfLiteXNNPackDelegateCreate)*>(
dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateCreate"));
auto xnnpack_delete =
reinterpret_cast<decltype(TfLiteXNNPackDelegateDelete)*>(
dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateDelete"));
if (xnnpack_options_default && xnnpack_create && xnnpack_delete) {
TfLiteXNNPackDelegateOptions options = xnnpack_options_default();
if (num_threads > 0) {
options.num_threads = num_threads;
}
tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr delegate(
xnnpack_create(&options), xnnpack_delete);
if (interpreter->ModifyGraphWithDelegate(std::move(delegate)) !=
kTfLiteOk) {
ThrowException(env, kIllegalArgumentException,
"Internal error: Failed to apply XNNPACK delegate: %s",
error_reporter->CachedErrorMessage());
}
} else {
ThrowException(env, kIllegalArgumentException,
"Failed to load XNNPACK delegate from current runtime. "
"Have you added the necessary dependencies?");
}
}
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env, Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
jclass clazz, jclass clazz,

View File

@ -54,6 +54,16 @@ public final class InterpreterMobileNetTest {
runMobileNetFloatTest(new Interpreter.Options().setNumThreads(2)); runMobileNetFloatTest(new Interpreter.Options().setNumThreads(2));
} }
@Test
public void testMobileNetEnhancedCpuKernels() {
runMobileNetFloatTest(new Interpreter.Options().setUseXNNPACK(true));
}
@Test
public void testMobileNetEnhancedCpuKernelsMultithreaded() {
runMobileNetFloatTest(new Interpreter.Options().setUseXNNPACK(true).setNumThreads(2));
}
@Test @Test
public void testMobileNetQuantized() { public void testMobileNetQuantized() {
runMobileNetQuantizedTest(new Interpreter.Options()); runMobileNetQuantizedTest(new Interpreter.Options());
@ -64,6 +74,12 @@ public final class InterpreterMobileNetTest {
runMobileNetQuantizedTest(new Interpreter.Options().setNumThreads(2)); runMobileNetQuantizedTest(new Interpreter.Options().setNumThreads(2));
} }
@Test
public void testMobileNetQuantizedEnhancedCpu() {
// The "enhanced CPU flag" should only impact float models, this is a sanity test to confirm.
runMobileNetQuantizedTest(new Interpreter.Options().setUseXNNPACK(true));
}
private static void runMobileNetFloatTest(Interpreter.Options options) { private static void runMobileNetFloatTest(Interpreter.Options options) {
ByteBuffer img = ByteBuffer img =
TestUtils.getTestImageAsFloatByteBuffer( TestUtils.getTestImageAsFloatByteBuffer(

View File

@ -409,6 +409,38 @@ public final class InterpreterTest {
interpreter.close(); interpreter.close();
} }
@Test
public void testUseXNNPACK() throws Exception {
Interpreter interpreter =
new Interpreter(MODEL_BUFFER, new Interpreter.Options().setUseXNNPACK(true));
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
float[][][][] parsedOutputs = new float[2][8][8][3];
interpreter.run(fourD, parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
interpreter.close();
}
@Test
public void testResizeWithEnhancedCpuKernels() throws Exception {
Interpreter interpreter =
new Interpreter(MODEL_BUFFER, new Interpreter.Options().setUseXNNPACK(true));
float[] input = {1.f};
float[] output = new float[1];
interpreter.run(input, output);
assertThat(output).usingTolerance(0.1f).containsExactly(new float[] {3.f}).inOrder();
// The new input shape should trigger a resize. Inference should still work properly.
float[] input2 = {1.f, 2.f};
float[] output2 = new float[2];
interpreter.run(input2, output2);
assertThat(output2).usingTolerance(0.1f).containsExactly(new float[] {3.f, 6.f}).inOrder();
}
@Test @Test
public void testRedundantClose() throws Exception { public void testRedundantClose() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_BUFFER); Interpreter interpreter = new Interpreter(MODEL_BUFFER);