Prepare to apply the XNNPACK delegate by default in the TFLite Java binding.

PiperOrigin-RevId: 343778493
Change-Id: I4e5fa33d59a4989b7b0734b17a93567af9b90ce7
This commit is contained in:
Chao Mei 2020-11-22 20:16:04 -08:00 committed by TensorFlower Gardener
parent cdc17e37e2
commit e8474b8065
7 changed files with 79 additions and 20 deletions

View File

@ -178,6 +178,11 @@ public final class Interpreter implements AutoCloseable {
Boolean allowFp16PrecisionForFp32; Boolean allowFp16PrecisionForFp32;
Boolean allowBufferHandleOutput; Boolean allowBufferHandleOutput;
Boolean allowCancellation; Boolean allowCancellation;
// TODO(b/171856982): update the comment when applying XNNPACK delegate by default is
// enabled for C++ TfLite library on Android platform.
// Note: the initial "null" value indicates default behavior which may mean XNNPACK
// delegate will be applied by default.
Boolean useXNNPACK; Boolean useXNNPACK;
final List<Delegate> delegates = new ArrayList<>(); final List<Delegate> delegates = new ArrayList<>();
} }

View File

@ -83,9 +83,17 @@ final class NativeInterpreterWrapper implements AutoCloseable {
allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue()); allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
} }
applyDelegates(options); applyDelegates(options);
// Simply use "-1" to represent the default mode.
int applyXNNPACKMode = -1;
if (options.useXNNPACK != null) { if (options.useXNNPACK != null) {
useXNNPACK( applyXNNPACKMode = options.useXNNPACK.booleanValue() ? 1 : 0;
interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads); }
// TODO(b/171856982): uncomment the following when applying XNNPACK delegate by default is
// enabled for C++ TfLite library on Android platform.
if (applyXNNPACKMode == 1 /*|| applyXNNPACKMode == -1*/) {
useXNNPACK(interpreterHandle, errorHandle, applyXNNPACKMode, options.numThreads);
} }
allocateTensors(interpreterHandle, errorHandle); allocateTensors(interpreterHandle, errorHandle);
this.isMemoryAllocated = true; this.isMemoryAllocated = true;
@ -459,7 +467,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow); private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);
private static native void useXNNPACK( private static native void useXNNPACK(
long interpreterHandle, long errorHandle, boolean state, int numThreads); long interpreterHandle, long errorHandle, int state, int numThreads);
private static native long createErrorReporter(int size); private static native long createErrorReporter(int size);

View File

@ -21,12 +21,14 @@ limitations under the License.
namespace tflite { namespace tflite {
// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in // The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in
// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the // the tflite namespace. This one instantiates a
// builtin ops. For smaller binary sizes users should avoid linking this in, and // BuiltinOpResolverWithoutDefaultDelegates, with all the builtin ops but
// should provide a custom make CreateOpResolver() instead. // without applying any TfLite delegates by default (like the XNNPACK delegate).
// For smaller binary sizes, users should avoid linking this in and should
// provide a custom CreateOpResolver() instead.
std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT
return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>( return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
new tflite::ops::builtin::BuiltinOpResolver()); new tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
} }
} // namespace tflite } // namespace tflite

View File

@ -33,9 +33,10 @@ void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...);
class BufferErrorReporter : public ErrorReporter { class BufferErrorReporter : public ErrorReporter {
public: public:
BufferErrorReporter(JNIEnv* env, int limit); BufferErrorReporter(JNIEnv* env, int limit);
virtual ~BufferErrorReporter(); ~BufferErrorReporter() override;
int Report(const char* format, va_list args) override; int Report(const char* format, va_list args) override;
const char* CachedErrorMessage(); const char* CachedErrorMessage();
using ErrorReporter::Report;
private: private:
char* buffer_; char* buffer_;

View File

@ -319,10 +319,10 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK( Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state, JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jint state,
jint num_threads) { jint num_threads) {
// If not using xnnpack, simply don't apply the delegate. // If not using xnnpack, simply don't apply the delegate.
if (!state) { if (state == 0) {
return; return;
} }
@ -369,6 +369,13 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
"Internal error: Failed to apply XNNPACK delegate: %s", "Internal error: Failed to apply XNNPACK delegate: %s",
error_reporter->CachedErrorMessage()); error_reporter->CachedErrorMessage());
} }
} else if (state == -1) {
// Instead of throwing an exception, we tolerate the absence of such
// dependencies because we try to apply the XNNPACK delegate by default.
TF_LITE_REPORT_ERROR(
error_reporter,
"WARNING: Missing necessary XNNPACK delegate dependencies to apply it "
"by default.\n");
} else { } else {
ThrowException(env, tflite::jni::kIllegalArgumentException, ThrowException(env, tflite::jni::kIllegalArgumentException,
"Failed to load XNNPACK delegate from current runtime. " "Failed to load XNNPACK delegate from current runtime. "

View File

@ -484,6 +484,10 @@ public final class InterpreterTest {
fail(); fail();
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
// Expected failure. // Expected failure.
} catch (IllegalArgumentException e) {
// As we could apply some TfLite delegate by default, the flex ops preparation could fail if
// the flex delegate isn't applied first, in which case this type of exception is thrown.
// Expected failure.
} }
} }

View File

@ -104,7 +104,15 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(MODEL_WITH_CUSTOM_OP_PATH); NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(MODEL_WITH_CUSTOM_OP_PATH);
fail(); fail();
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
assertThat(e).hasMessageThat().contains("Encountered unresolved custom op: Assign"); assertThat(e)
.hasMessageThat()
.contains("preparing tensor allocations: Encountered unresolved custom op: Assign");
} catch (IllegalArgumentException e) {
// As we could apply a TfLite delegate by default, during which the preparation of this
// unresolved custom op could fail, this type of exception may be thrown instead.
assertThat(e)
.hasMessageThat()
.containsMatch("Failed to apply .* delegate: Encountered unresolved custom op: Assign");
} }
} }
@ -201,8 +209,20 @@ public final class NativeInterpreterWrapperTest {
outputs.put(0, parsedOutputs); outputs.put(0, parsedOutputs);
wrapper.run(inputs, outputs); wrapper.run(inputs, outputs);
long[] outputOneD = parsedOutputs[0][0][0]; long[] outputOneD = parsedOutputs[0][0][0];
long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, long[] expected = {
-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L}; -892834092L,
923423L,
2123918239018L,
-892834092L,
923423L,
2123918239018L,
-892834092L,
923423L,
2123918239018L,
-892834092L,
923423L,
2123918239018L
};
assertThat(outputOneD).isEqualTo(expected); assertThat(outputOneD).isEqualTo(expected);
} }
} }
@ -222,8 +242,20 @@ public final class NativeInterpreterWrapperTest {
outputs.put(0, parsedOutputs); outputs.put(0, parsedOutputs);
wrapper.run(inputs, outputs); wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0]; byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, byte[] expected = {
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0}; (byte) 0xe0,
0x4f,
(byte) 0xd0,
(byte) 0xe0,
0x4f,
(byte) 0xd0,
(byte) 0xe0,
0x4f,
(byte) 0xd0,
(byte) 0xe0,
0x4f,
(byte) 0xd0
};
assertThat(outputOneD).isEqualTo(expected); assertThat(outputOneD).isEqualTo(expected);
} }
} }
@ -242,7 +274,7 @@ public final class NativeInterpreterWrapperTest {
wrapper.run(inputs, outputs); wrapper.run(inputs, outputs);
String[] outputOneD = parsedOutputs[0][0][0]; String[] outputOneD = parsedOutputs[0][0][0];
String[] expected = { String[] expected = {
"s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333" "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333"
}; };
assertThat(outputOneD).isEqualTo(expected); assertThat(outputOneD).isEqualTo(expected);
} }
@ -276,8 +308,8 @@ public final class NativeInterpreterWrapperTest {
wrapper.run(inputs, outputs); wrapper.run(inputs, outputs);
String[] outputOneD = parsedOutputs[0][0][0]; String[] outputOneD = parsedOutputs[0][0][0];
String[] expected = { String[] expected = {
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e",
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e" "\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e"
}; };
assertThat(outputOneD).isEqualTo(expected); assertThat(outputOneD).isEqualTo(expected);
} }
@ -332,8 +364,8 @@ public final class NativeInterpreterWrapperTest {
wrapper.run(inputs, outputs); wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0]; byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = { byte[] expected = {
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0 (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
}; };
assertThat(outputOneD).isEqualTo(expected); assertThat(outputOneD).isEqualTo(expected);
} }