Add flag for using optimized TFLite CPU kernels on Android
Add an experimental flag which allows opting in to a set of highly optimized floating point kernels provided via the XNNPACK delegate. This is offered as a preview, with the plan to enable these kernels by default in a future release.

The flag can be enabled via:

    Interpreter.Options options = new Interpreter.Options().setUseXNNPACK(true);

See tensorflow/lite/delegates/xnnpack/README.md for more details about these kernels and the associated delegate functionality.

PiperOrigin-RevId: 316909226
Change-Id: Ib60cf259225b8a48a9830ccbb24ec10534b038ce
This commit is contained in:
parent 2779d9e29d
commit 23d482eaa2
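As a usage sketch (not part of the diff below): the snippet shows how a Java client might opt in to the new kernels end to end. The model file name, tensor shapes, and thread count are illustrative assumptions only.

    // Sketch: opting in to the experimental XNNPACK-backed CPU kernels.
    // "model.tflite" and the 1x4 tensor shapes are placeholders for illustration.
    import java.io.File;
    import org.tensorflow.lite.Interpreter;

    public final class XnnpackOptInExample {
      public static void main(String[] args) {
        Interpreter.Options options =
            new Interpreter.Options()
                .setUseXNNPACK(true) // experimental flag added by this commit
                .setNumThreads(2);   // the thread count is forwarded to the XNNPACK delegate
        try (Interpreter interpreter = new Interpreter(new File("model.tflite"), options)) {
          float[][] input = new float[1][4];
          float[][] output = new float[1][4];
          interpreter.run(input, output);
        }
      }
    }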
@@ -21,6 +21,7 @@ cc_library(
     linkstatic = True,
     deps = [
         "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite:minimal_logging",
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs",
@@ -47,6 +48,7 @@ cc_library(
     linkstatic = True,
     deps = [
         "//tensorflow/lite:kernel_api",
+        "//tensorflow/lite:minimal_logging",
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/schema:schema_fbs",
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/lite/builtin_ops.h"
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/minimal_logging.h"
 #include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"

 namespace tflite {
@@ -52,6 +53,8 @@ class Delegate {
           pthreadpool_create(static_cast<size_t>(options->num_threads)));
     }
 #endif
+    TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
+                         "Created TensorFlow Lite XNNPACK delegate for CPU.");
   }

   TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context);
@@ -408,6 +408,7 @@ tflite_jni_binary(
         "//tensorflow/lite/c:c_api",
         "//tensorflow/lite/c:c_api_experimental",
         "//tensorflow/lite/delegates/nnapi/java/src/main/native",
+        "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
         "//tensorflow/lite/java/src/main/native",
     ],
 )
@@ -137,10 +137,37 @@ public final class Interpreter implements AutoCloseable {
       return this;
     }

+    /**
+     * Experimental: Enable an optimized set of floating point CPU kernels (provided by XNNPACK).
+     *
+     * <p>Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided
+     * via the XNNPACK delegate. Currently, this is restricted to a subset of floating point
+     * operations. Eventually, we plan to enable this by default, as it can provide significant
+     * performance benefits for many classes of floating point models. See
+     * https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
+     * for more details.
+     *
+     * <p>Things to keep in mind when enabling this flag:
+     *
+     * <ul>
+     *   <li>Startup time and resize time may increase.
+     *   <li>Baseline memory consumption may increase.
+     *   <li>Compatibility with other delegates (e.g., GPU) has not been fully validated.
+     *   <li>Quantized models will not see any benefit.
+     * </ul>
+     *
+     * <p>WARNING: This is an experimental interface that is subject to change.
+     */
+    public Options setUseXNNPACK(boolean useXNNPACK) {
+      this.useXNNPACK = useXNNPACK;
+      return this;
+    }
+
     int numThreads = -1;
     Boolean useNNAPI;
     Boolean allowFp16PrecisionForFp32;
     Boolean allowBufferHandleOutput;
+    Boolean useXNNPACK;
     final List<Delegate> delegates = new ArrayList<>();
   }

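The Javadoc above spells out the caveats; one way a caller might act on the "quantized models will not see any benefit" note is sketched below. The isFloatModel hint is a hypothetical caller-provided flag, not part of this API.

    // Sketch: request the XNNPACK-backed kernels only for float models.
    // "isFloatModel" is a hypothetical hint supplied by the caller.
    static Interpreter.Options buildOptions(boolean isFloatModel) {
      Interpreter.Options options = new Interpreter.Options().setNumThreads(4);
      if (isFloatModel) {
        options.setUseXNNPACK(true);
      }
      return options;
    }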
@@ -80,6 +80,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
       allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
     }
     applyDelegates(options);
+    if (options.useXNNPACK != null) {
+      useXNNPACK(
+          interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads);
+    }
     allocateTensors(interpreterHandle, errorHandle);
     this.isMemoryAllocated = true;
   }
@@ -438,6 +442,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {

   private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);

+  private static native void useXNNPACK(
+      long interpreterHandle, long errorHandle, boolean state, int numThreads);
+
   private static native long createErrorReporter(int size);

   private static native long createModel(String modelPathOrBuffer, long errorHandle);
@@ -31,6 +31,7 @@ cc_library(
         "//tensorflow/lite:string_util",
         "//tensorflow/lite:util",
         "//tensorflow/lite/c:common",
+        "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
         "//tensorflow/lite/experimental/tflite_api_dispatcher:tflite_api_dispatcher_with_kernels",
         "//tensorflow/lite/java/jni",
     ],
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include <dlfcn.h>
 #include <jni.h>
 #include <stdio.h>
 #include <time.h>
@@ -20,6 +21,7 @@ limitations under the License.
 #include <vector>

 #include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
 #include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h"
 #include "tensorflow/lite/java/src/main/native/jni_utils.h"
 #include "tensorflow/lite/util.h"
@@ -323,6 +325,59 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
   interpreter->SetAllowBufferHandleOutput(allow);
 }

+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
+    JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state,
+    jint num_threads) {
+  // If not using xnnpack, simply don't apply the delegate.
+  if (!state) {
+    return;
+  }
+
+  tflite_api_dispatcher::Interpreter* interpreter =
+      convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) {
+    return;
+  }
+
+  BufferErrorReporter* error_reporter =
+      convertLongToErrorReporter(env, error_handle);
+  if (error_reporter == nullptr) {
+    return;
+  }
+
+  // We use dynamic loading to avoid taking a hard dependency on XNNPack.
+  // This allows clients that use trimmed builds to save on binary size.
+  auto xnnpack_options_default =
+      reinterpret_cast<decltype(TfLiteXNNPackDelegateOptionsDefault)*>(
+          dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateOptionsDefault"));
+  auto xnnpack_create =
+      reinterpret_cast<decltype(TfLiteXNNPackDelegateCreate)*>(
+          dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateCreate"));
+  auto xnnpack_delete =
+      reinterpret_cast<decltype(TfLiteXNNPackDelegateDelete)*>(
+          dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateDelete"));
+
+  if (xnnpack_options_default && xnnpack_create && xnnpack_delete) {
+    TfLiteXNNPackDelegateOptions options = xnnpack_options_default();
+    if (num_threads > 0) {
+      options.num_threads = num_threads;
+    }
+    tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr delegate(
+        xnnpack_create(&options), xnnpack_delete);
+    if (interpreter->ModifyGraphWithDelegate(std::move(delegate)) !=
+        kTfLiteOk) {
+      ThrowException(env, kIllegalArgumentException,
+                     "Internal error: Failed to apply XNNPACK delegate: %s",
+                     error_reporter->CachedErrorMessage());
+    }
+  } else {
+    ThrowException(env, kIllegalArgumentException,
+                   "Failed to load XNNPACK delegate from current runtime. "
+                   "Have you added the necessary dependencies?");
+  }
+}
+
 JNIEXPORT void JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
                                                              jclass clazz,
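Because the JNI shim above resolves the delegate entry points with dlsym and throws IllegalArgumentException both when the symbols are missing and when the delegate fails to apply, a client shipping trimmed builds can fall back to the default kernels. A minimal sketch, assuming the exception surfaces from Interpreter construction:

    import java.nio.ByteBuffer;
    import org.tensorflow.lite.Interpreter;

    final class InterpreterFactory {
      // Sketch: prefer the XNNPACK-backed kernels, but fall back to the default
      // CPU kernels if the delegate cannot be loaded or applied.
      static Interpreter create(ByteBuffer modelBuffer) {
        try {
          return new Interpreter(modelBuffer, new Interpreter.Options().setUseXNNPACK(true));
        } catch (IllegalArgumentException e) {
          // XNNPACK symbols were stripped from this build, or the delegate failed to apply.
          return new Interpreter(modelBuffer, new Interpreter.Options());
        }
      }
    }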
@@ -54,6 +54,16 @@ public final class InterpreterMobileNetTest {
     runMobileNetFloatTest(new Interpreter.Options().setNumThreads(2));
   }

+  @Test
+  public void testMobileNetEnhancedCpuKernels() {
+    runMobileNetFloatTest(new Interpreter.Options().setUseXNNPACK(true));
+  }
+
+  @Test
+  public void testMobileNetEnhancedCpuKernelsMultithreaded() {
+    runMobileNetFloatTest(new Interpreter.Options().setUseXNNPACK(true).setNumThreads(2));
+  }
+
   @Test
   public void testMobileNetQuantized() {
     runMobileNetQuantizedTest(new Interpreter.Options());
@@ -64,6 +74,12 @@ public final class InterpreterMobileNetTest {
     runMobileNetQuantizedTest(new Interpreter.Options().setNumThreads(2));
   }

+  @Test
+  public void testMobileNetQuantizedEnhancedCpu() {
+    // The "enhanced CPU flag" should only impact float models, this is a sanity test to confirm.
+    runMobileNetQuantizedTest(new Interpreter.Options().setUseXNNPACK(true));
+  }
+
   private static void runMobileNetFloatTest(Interpreter.Options options) {
     ByteBuffer img =
         TestUtils.getTestImageAsFloatByteBuffer(
@@ -409,6 +409,38 @@ public final class InterpreterTest {
     interpreter.close();
   }

+  @Test
+  public void testUseXNNPACK() throws Exception {
+    Interpreter interpreter =
+        new Interpreter(MODEL_BUFFER, new Interpreter.Options().setUseXNNPACK(true));
+    float[] oneD = {1.23f, 6.54f, 7.81f};
+    float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+    float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+    float[][][][] fourD = {threeD, threeD};
+    float[][][][] parsedOutputs = new float[2][8][8][3];
+    interpreter.run(fourD, parsedOutputs);
+    float[] outputOneD = parsedOutputs[0][0][0];
+    float[] expected = {3.69f, 19.62f, 23.43f};
+    assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+    interpreter.close();
+  }
+
+  @Test
+  public void testResizeWithEnhancedCpuKernels() throws Exception {
+    Interpreter interpreter =
+        new Interpreter(MODEL_BUFFER, new Interpreter.Options().setUseXNNPACK(true));
+    float[] input = {1.f};
+    float[] output = new float[1];
+    interpreter.run(input, output);
+    assertThat(output).usingTolerance(0.1f).containsExactly(new float[] {3.f}).inOrder();
+
+    // The new input shape should trigger a resize. Inference should still work properly.
+    float[] input2 = {1.f, 2.f};
+    float[] output2 = new float[2];
+    interpreter.run(input2, output2);
+    assertThat(output2).usingTolerance(0.1f).containsExactly(new float[] {3.f, 6.f}).inOrder();
+  }
+
   @Test
   public void testRedundantClose() throws Exception {
     Interpreter interpreter = new Interpreter(MODEL_BUFFER);