Add flag for using optimized TFLite CPU kernels on Android
Add an experimental flag which allows opting in to a set of highly optimized floating point kernels provided via the XNNPACK delegate. This is offered as a preview, with the plan to enable these kernels by default in a future release. The flag can be enabled via: Interpreter.Options options = new Interpreter.Options().setUseXNNPACK(true); See tensorflow/lite/delegates/xnnpack/README.md for more details about these kernels and the associated delegate functionality. PiperOrigin-RevId: 316909226 Change-Id: Ib60cf259225b8a48a9830ccbb24ec10534b038ce
This commit is contained in:
parent
2779d9e29d
commit
23d482eaa2
|
@ -21,6 +21,7 @@ cc_library(
|
|||
linkstatic = True,
|
||||
deps = [
|
||||
"//tensorflow/lite:kernel_api",
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
|
@ -47,6 +48,7 @@ cc_library(
|
|||
linkstatic = True,
|
||||
deps = [
|
||||
"//tensorflow/lite:kernel_api",
|
||||
"//tensorflow/lite:minimal_logging",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
|
|
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||
#include "tensorflow/lite/builtin_ops.h"
|
||||
#include "tensorflow/lite/c/builtin_op_data.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/minimal_logging.h"
|
||||
#include "tensorflow/lite/tools/optimize/sparsity/format_converter.h"
|
||||
|
||||
namespace tflite {
|
||||
|
@ -52,6 +53,8 @@ class Delegate {
|
|||
pthreadpool_create(static_cast<size_t>(options->num_threads)));
|
||||
}
|
||||
#endif
|
||||
TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
|
||||
"Created TensorFlow Lite XNNPACK delegate for CPU.");
|
||||
}
|
||||
|
||||
TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context);
|
||||
|
|
|
@ -408,6 +408,7 @@ tflite_jni_binary(
|
|||
"//tensorflow/lite/c:c_api",
|
||||
"//tensorflow/lite/c:c_api_experimental",
|
||||
"//tensorflow/lite/delegates/nnapi/java/src/main/native",
|
||||
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||
"//tensorflow/lite/java/src/main/native",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -137,10 +137,37 @@ public final class Interpreter implements AutoCloseable {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Experimental: Enable an optimized set of floating point CPU kernels (provided by XNNPACK).
|
||||
*
|
||||
* <p>Enabling this flag will enable use of a new, highly optimized set of CPU kernels provided
|
||||
* via the XNNPACK delegate. Currently, this is restricted to a subset of floating point
|
||||
* operations. Eventually, we plan to enable this by default, as it can provide significant
|
||||
* peformance benefits for many classes of floating point models. See
|
||||
* https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/delegates/xnnpack/README.md
|
||||
* for more details.
|
||||
*
|
||||
* <p>Things to keep in mind when enabling this flag:
|
||||
*
|
||||
* <ul>
|
||||
* <li>Startup time and resize time may increase.
|
||||
* <li>Baseline memory consumption may increase.
|
||||
* <li>Compatibility with other delegates (e.g., GPU) has not been fully validated.
|
||||
* <li>Quantized models will not see any benefit.
|
||||
* </ul>
|
||||
*
|
||||
* <p>WARNING: This is an experimental interface that is subject to change.
|
||||
*/
|
||||
public Options setUseXNNPACK(boolean useXNNPACK) {
|
||||
this.useXNNPACK = useXNNPACK;
|
||||
return this;
|
||||
}
|
||||
|
||||
int numThreads = -1;
|
||||
Boolean useNNAPI;
|
||||
Boolean allowFp16PrecisionForFp32;
|
||||
Boolean allowBufferHandleOutput;
|
||||
Boolean useXNNPACK;
|
||||
final List<Delegate> delegates = new ArrayList<>();
|
||||
}
|
||||
|
||||
|
|
|
@ -80,6 +80,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
|||
allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
|
||||
}
|
||||
applyDelegates(options);
|
||||
if (options.useXNNPACK != null) {
|
||||
useXNNPACK(
|
||||
interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads);
|
||||
}
|
||||
allocateTensors(interpreterHandle, errorHandle);
|
||||
this.isMemoryAllocated = true;
|
||||
}
|
||||
|
@ -438,6 +442,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
|||
|
||||
private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);
|
||||
|
||||
private static native void useXNNPACK(
|
||||
long interpreterHandle, long errorHandle, boolean state, int numThreads);
|
||||
|
||||
private static native long createErrorReporter(int size);
|
||||
|
||||
private static native long createModel(String modelPathOrBuffer, long errorHandle);
|
||||
|
|
|
@ -31,6 +31,7 @@ cc_library(
|
|||
"//tensorflow/lite:string_util",
|
||||
"//tensorflow/lite:util",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/delegates/xnnpack:xnnpack_delegate_hdrs_only",
|
||||
"//tensorflow/lite/experimental/tflite_api_dispatcher:tflite_api_dispatcher_with_kernels",
|
||||
"//tensorflow/lite/java/jni",
|
||||
],
|
||||
|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <jni.h>
|
||||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
|
@ -20,6 +21,7 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||
#include "tensorflow/lite/experimental/tflite_api_dispatcher/tflite_api_dispatcher.h"
|
||||
#include "tensorflow/lite/java/src/main/native/jni_utils.h"
|
||||
#include "tensorflow/lite/util.h"
|
||||
|
@ -323,6 +325,59 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
|
|||
interpreter->SetAllowBufferHandleOutput(allow);
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state,
|
||||
jint num_threads) {
|
||||
// If not using xnnpack, simply don't apply the delegate.
|
||||
if (!state) {
|
||||
return;
|
||||
}
|
||||
|
||||
tflite_api_dispatcher::Interpreter* interpreter =
|
||||
convertLongToInterpreter(env, handle);
|
||||
if (interpreter == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
BufferErrorReporter* error_reporter =
|
||||
convertLongToErrorReporter(env, error_handle);
|
||||
if (error_reporter == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We use dynamic loading to avoid taking a hard dependency on XNNPack.
|
||||
// This allows clients that use trimmed builds to save on binary size.
|
||||
auto xnnpack_options_default =
|
||||
reinterpret_cast<decltype(TfLiteXNNPackDelegateOptionsDefault)*>(
|
||||
dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateOptionsDefault"));
|
||||
auto xnnpack_create =
|
||||
reinterpret_cast<decltype(TfLiteXNNPackDelegateCreate)*>(
|
||||
dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateCreate"));
|
||||
auto xnnpack_delete =
|
||||
reinterpret_cast<decltype(TfLiteXNNPackDelegateDelete)*>(
|
||||
dlsym(RTLD_DEFAULT, "TfLiteXNNPackDelegateDelete"));
|
||||
|
||||
if (xnnpack_options_default && xnnpack_create && xnnpack_delete) {
|
||||
TfLiteXNNPackDelegateOptions options = xnnpack_options_default();
|
||||
if (num_threads > 0) {
|
||||
options.num_threads = num_threads;
|
||||
}
|
||||
tflite_api_dispatcher::Interpreter::TfLiteDelegatePtr delegate(
|
||||
xnnpack_create(&options), xnnpack_delete);
|
||||
if (interpreter->ModifyGraphWithDelegate(std::move(delegate)) !=
|
||||
kTfLiteOk) {
|
||||
ThrowException(env, kIllegalArgumentException,
|
||||
"Internal error: Failed to apply XNNPACK delegate: %s",
|
||||
error_reporter->CachedErrorMessage());
|
||||
}
|
||||
} else {
|
||||
ThrowException(env, kIllegalArgumentException,
|
||||
"Failed to load XNNPACK delegate from current runtime. "
|
||||
"Have you added the necessary dependencies?");
|
||||
}
|
||||
}
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
|
||||
jclass clazz,
|
||||
|
|
|
@ -54,6 +54,16 @@ public final class InterpreterMobileNetTest {
|
|||
runMobileNetFloatTest(new Interpreter.Options().setNumThreads(2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMobileNetEnhancedCpuKernels() {
|
||||
runMobileNetFloatTest(new Interpreter.Options().setUseXNNPACK(true));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMobileNetEnhancedCpuKernelsMultithreaded() {
|
||||
runMobileNetFloatTest(new Interpreter.Options().setUseXNNPACK(true).setNumThreads(2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMobileNetQuantized() {
|
||||
runMobileNetQuantizedTest(new Interpreter.Options());
|
||||
|
@ -64,6 +74,12 @@ public final class InterpreterMobileNetTest {
|
|||
runMobileNetQuantizedTest(new Interpreter.Options().setNumThreads(2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMobileNetQuantizedEnhancedCpu() {
|
||||
// The "enhanced CPU flag" should only impact float models, this is a sanity test to confirm.
|
||||
runMobileNetQuantizedTest(new Interpreter.Options().setUseXNNPACK(true));
|
||||
}
|
||||
|
||||
private static void runMobileNetFloatTest(Interpreter.Options options) {
|
||||
ByteBuffer img =
|
||||
TestUtils.getTestImageAsFloatByteBuffer(
|
||||
|
|
|
@ -409,6 +409,38 @@ public final class InterpreterTest {
|
|||
interpreter.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUseXNNPACK() throws Exception {
|
||||
Interpreter interpreter =
|
||||
new Interpreter(MODEL_BUFFER, new Interpreter.Options().setUseXNNPACK(true));
|
||||
float[] oneD = {1.23f, 6.54f, 7.81f};
|
||||
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||
float[][][][] fourD = {threeD, threeD};
|
||||
float[][][][] parsedOutputs = new float[2][8][8][3];
|
||||
interpreter.run(fourD, parsedOutputs);
|
||||
float[] outputOneD = parsedOutputs[0][0][0];
|
||||
float[] expected = {3.69f, 19.62f, 23.43f};
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
interpreter.close();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testResizeWithEnhancedCpuKernels() throws Exception {
|
||||
Interpreter interpreter =
|
||||
new Interpreter(MODEL_BUFFER, new Interpreter.Options().setUseXNNPACK(true));
|
||||
float[] input = {1.f};
|
||||
float[] output = new float[1];
|
||||
interpreter.run(input, output);
|
||||
assertThat(output).usingTolerance(0.1f).containsExactly(new float[] {3.f}).inOrder();
|
||||
|
||||
// The new input shape should trigger a resize. Inference should still work properly.
|
||||
float[] input2 = {1.f, 2.f};
|
||||
float[] output2 = new float[2];
|
||||
interpreter.run(input2, output2);
|
||||
assertThat(output2).usingTolerance(0.1f).containsExactly(new float[] {3.f, 6.f}).inOrder();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRedundantClose() throws Exception {
|
||||
Interpreter interpreter = new Interpreter(MODEL_BUFFER);
|
||||
|
|
Loading…
Reference in New Issue