Prepare to apply the XNNPACK delegate by default in the TFLite Java binding.

PiperOrigin-RevId: 343778493
Change-Id: I4e5fa33d59a4989b7b0734b17a93567af9b90ce7
This commit is contained in:
Chao Mei 2020-11-22 20:16:04 -08:00 committed by TensorFlower Gardener
parent cdc17e37e2
commit e8474b8065
7 changed files with 79 additions and 20 deletions

View File

@ -178,6 +178,11 @@ public final class Interpreter implements AutoCloseable {
Boolean allowFp16PrecisionForFp32;
Boolean allowBufferHandleOutput;
Boolean allowCancellation;
// TODO(b/171856982): update the comment when applying XNNPACK delegate by default is
// enabled for C++ TfLite library on Android platform.
// Note: the initial "null" value indicates default behavior which may mean XNNPACK
// delegate will be applied by default.
Boolean useXNNPACK;
final List<Delegate> delegates = new ArrayList<>();
}

View File

@ -83,9 +83,17 @@ final class NativeInterpreterWrapper implements AutoCloseable {
allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
}
applyDelegates(options);
// Simply use "-1" to represent the default mode.
int applyXNNPACKMode = -1;
if (options.useXNNPACK != null) {
useXNNPACK(
interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads);
applyXNNPACKMode = options.useXNNPACK.booleanValue() ? 1 : 0;
}
// TODO(b/171856982): uncomment the following when applying XNNPACK delegate by default is
// enabled for C++ TfLite library on Android platform.
if (applyXNNPACKMode == 1 /*|| applyXNNPACKMode == -1*/) {
useXNNPACK(interpreterHandle, errorHandle, applyXNNPACKMode, options.numThreads);
}
allocateTensors(interpreterHandle, errorHandle);
this.isMemoryAllocated = true;
@ -459,7 +467,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);
private static native void useXNNPACK(
long interpreterHandle, long errorHandle, boolean state, int numThreads);
long interpreterHandle, long errorHandle, int state, int numThreads);
private static native long createErrorReporter(int size);

View File

@ -21,12 +21,14 @@ limitations under the License.
namespace tflite {
// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in
// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the
// builtin ops. For smaller binary sizes users should avoid linking this in, and
// should provide a custom make CreateOpResolver() instead.
// the tflite namespace. This one instantiates a
// BuiltinOpResolverWithoutDefaultDelegates, with all the builtin ops but
// without applying any TfLite delegates by default (like the XNNPACK delegate).
// For smaller binary sizes users should avoid linking this in, and should
// provide a custom CreateOpResolver() instead.
std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT
return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
new tflite::ops::builtin::BuiltinOpResolver());
new tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
}
} // namespace tflite

View File

@ -33,9 +33,10 @@ void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...);
class BufferErrorReporter : public ErrorReporter {
public:
BufferErrorReporter(JNIEnv* env, int limit);
virtual ~BufferErrorReporter();
~BufferErrorReporter() override;
int Report(const char* format, va_list args) override;
const char* CachedErrorMessage();
using ErrorReporter::Report;
private:
char* buffer_;

View File

@ -319,10 +319,10 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state,
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jint state,
jint num_threads) {
// If not using xnnpack, simply don't apply the delegate.
if (!state) {
if (state == 0) {
return;
}
@ -369,6 +369,13 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
"Internal error: Failed to apply XNNPACK delegate: %s",
error_reporter->CachedErrorMessage());
}
} else if (state == -1) {
// Instead of throwing an exception, we tolerate the missing dependencies
// here because the XNNPACK delegate is only being applied by default.
TF_LITE_REPORT_ERROR(
error_reporter,
"WARNING: Missing necessary XNNPACK delegate dependencies to apply it "
"by default.\n");
} else {
ThrowException(env, tflite::jni::kIllegalArgumentException,
"Failed to load XNNPACK delegate from current runtime. "

View File

@ -484,6 +484,10 @@ public final class InterpreterTest {
fail();
} catch (IllegalStateException e) {
// Expected failure.
} catch (IllegalArgumentException e) {
// As we could apply some TfLite delegate by default, the flex ops preparation could fail if
// the flex delegate isn't applied first, in which case this type of exception is thrown.
// Expected failure.
}
}

View File

@ -104,7 +104,15 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(MODEL_WITH_CUSTOM_OP_PATH);
fail();
} catch (IllegalStateException e) {
assertThat(e).hasMessageThat().contains("Encountered unresolved custom op: Assign");
assertThat(e)
.hasMessageThat()
.contains("preparing tensor allocations: Encountered unresolved custom op: Assign");
} catch (IllegalArgumentException e) {
// As we could apply a TfLite delegate by default, the preparation of this unresolved
// custom op could fail while the delegate is applied, in which case this type of
// exception is thrown.
assertThat(e)
.hasMessageThat()
.containsMatch("Failed to apply .* delegate: Encountered unresolved custom op: Assign");
}
}
@ -201,8 +209,20 @@ public final class NativeInterpreterWrapperTest {
outputs.put(0, parsedOutputs);
wrapper.run(inputs, outputs);
long[] outputOneD = parsedOutputs[0][0][0];
long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
long[] expected = {
-892834092L,
923423L,
2123918239018L,
-892834092L,
923423L,
2123918239018L,
-892834092L,
923423L,
2123918239018L,
-892834092L,
923423L,
2123918239018L
};
assertThat(outputOneD).isEqualTo(expected);
}
}
@ -222,8 +242,20 @@ public final class NativeInterpreterWrapperTest {
outputs.put(0, parsedOutputs);
wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
byte[] expected = {
(byte) 0xe0,
0x4f,
(byte) 0xd0,
(byte) 0xe0,
0x4f,
(byte) 0xd0,
(byte) 0xe0,
0x4f,
(byte) 0xd0,
(byte) 0xe0,
0x4f,
(byte) 0xd0
};
assertThat(outputOneD).isEqualTo(expected);
}
}
@ -242,7 +274,7 @@ public final class NativeInterpreterWrapperTest {
wrapper.run(inputs, outputs);
String[] outputOneD = parsedOutputs[0][0][0];
String[] expected = {
"s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333"
"s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333"
};
assertThat(outputOneD).isEqualTo(expected);
}
@ -276,8 +308,8 @@ public final class NativeInterpreterWrapperTest {
wrapper.run(inputs, outputs);
String[] outputOneD = parsedOutputs[0][0][0];
String[] expected = {
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e",
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e"
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e",
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e"
};
assertThat(outputOneD).isEqualTo(expected);
}
@ -332,8 +364,8 @@ public final class NativeInterpreterWrapperTest {
wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
};
assertThat(outputOneD).isEqualTo(expected);
}