Prepare to make the XNNPACK delegate applied by default in the TFLite Java binding.
PiperOrigin-RevId: 343778493 Change-Id: I4e5fa33d59a4989b7b0734b17a93567af9b90ce7
This commit is contained in:
parent
cdc17e37e2
commit
e8474b8065
@ -178,6 +178,11 @@ public final class Interpreter implements AutoCloseable {
|
|||||||
Boolean allowFp16PrecisionForFp32;
|
Boolean allowFp16PrecisionForFp32;
|
||||||
Boolean allowBufferHandleOutput;
|
Boolean allowBufferHandleOutput;
|
||||||
Boolean allowCancellation;
|
Boolean allowCancellation;
|
||||||
|
|
||||||
|
// TODO(b/171856982): update the comment when applying XNNPACK delegate by default is
|
||||||
|
// enabled for C++ TfLite library on Android platform.
|
||||||
|
// Note: the initial "null" value indicates default behavior which may mean XNNPACK
|
||||||
|
// delegate will be applied by default.
|
||||||
Boolean useXNNPACK;
|
Boolean useXNNPACK;
|
||||||
final List<Delegate> delegates = new ArrayList<>();
|
final List<Delegate> delegates = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -83,9 +83,17 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
|||||||
allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
|
allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
|
||||||
}
|
}
|
||||||
applyDelegates(options);
|
applyDelegates(options);
|
||||||
|
|
||||||
|
// Simply use "-1" to represent the default mode.
|
||||||
|
int applyXNNPACKMode = -1;
|
||||||
if (options.useXNNPACK != null) {
|
if (options.useXNNPACK != null) {
|
||||||
useXNNPACK(
|
applyXNNPACKMode = options.useXNNPACK.booleanValue() ? 1 : 0;
|
||||||
interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads);
|
}
|
||||||
|
|
||||||
|
// TODO(b/171856982): uncomment the following when applying XNNPACK delegate by default is
|
||||||
|
// enabled for C++ TfLite library on Android platform.
|
||||||
|
if (applyXNNPACKMode == 1 /*|| applyXNNPACKMode == -1*/) {
|
||||||
|
useXNNPACK(interpreterHandle, errorHandle, applyXNNPACKMode, options.numThreads);
|
||||||
}
|
}
|
||||||
allocateTensors(interpreterHandle, errorHandle);
|
allocateTensors(interpreterHandle, errorHandle);
|
||||||
this.isMemoryAllocated = true;
|
this.isMemoryAllocated = true;
|
||||||
@ -459,7 +467,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
|||||||
private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);
|
private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);
|
||||||
|
|
||||||
private static native void useXNNPACK(
|
private static native void useXNNPACK(
|
||||||
long interpreterHandle, long errorHandle, boolean state, int numThreads);
|
long interpreterHandle, long errorHandle, int state, int numThreads);
|
||||||
|
|
||||||
private static native long createErrorReporter(int size);
|
private static native long createErrorReporter(int size);
|
||||||
|
|
||||||
|
|||||||
@ -21,12 +21,14 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
|
||||||
// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in
|
// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in
|
||||||
// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the
|
// the tflite namespace. This one instantiates a
|
||||||
// builtin ops. For smaller binary sizes users should avoid linking this in, and
|
// BuiltinOpResolverWithoutDefaultDelegates, with all the builtin ops but
|
||||||
// should provide a custom make CreateOpResolver() instead.
|
// without applying any TfLite delegates by default (like the XNNPACK delegate).
|
||||||
|
// For smaller binary sizes users should avoid linking this in, and should
|
||||||
|
// provide a custom CreateOpResolver() instead.
|
||||||
std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT
|
std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT
|
||||||
return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
|
return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
|
||||||
new tflite::ops::builtin::BuiltinOpResolver());
|
new tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|||||||
@ -33,9 +33,10 @@ void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...);
|
|||||||
class BufferErrorReporter : public ErrorReporter {
|
class BufferErrorReporter : public ErrorReporter {
|
||||||
public:
|
public:
|
||||||
BufferErrorReporter(JNIEnv* env, int limit);
|
BufferErrorReporter(JNIEnv* env, int limit);
|
||||||
virtual ~BufferErrorReporter();
|
~BufferErrorReporter() override;
|
||||||
int Report(const char* format, va_list args) override;
|
int Report(const char* format, va_list args) override;
|
||||||
const char* CachedErrorMessage();
|
const char* CachedErrorMessage();
|
||||||
|
using ErrorReporter::Report;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
char* buffer_;
|
char* buffer_;
|
||||||
|
|||||||
@ -319,10 +319,10 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
|
|||||||
|
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
|
Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
|
||||||
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state,
|
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jint state,
|
||||||
jint num_threads) {
|
jint num_threads) {
|
||||||
// If not using xnnpack, simply don't apply the delegate.
|
// If not using xnnpack, simply don't apply the delegate.
|
||||||
if (!state) {
|
if (state == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -369,6 +369,13 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
|
|||||||
"Internal error: Failed to apply XNNPACK delegate: %s",
|
"Internal error: Failed to apply XNNPACK delegate: %s",
|
||||||
error_reporter->CachedErrorMessage());
|
error_reporter->CachedErrorMessage());
|
||||||
}
|
}
|
||||||
|
} else if (state == -1) {
|
||||||
|
// Instead of throwing an exception, we tolerate the absence of such
|
||||||
|
// dependencies because we try to apply XNNPACK delegate by default.
|
||||||
|
TF_LITE_REPORT_ERROR(
|
||||||
|
error_reporter,
|
||||||
|
"WARNING: Missing necessary XNNPACK delegate dependencies to apply it "
|
||||||
|
"by default.\n");
|
||||||
} else {
|
} else {
|
||||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||||
"Failed to load XNNPACK delegate from current runtime. "
|
"Failed to load XNNPACK delegate from current runtime. "
|
||||||
|
|||||||
@ -484,6 +484,10 @@ public final class InterpreterTest {
|
|||||||
fail();
|
fail();
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
// Expected failure.
|
// Expected failure.
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
// As we could apply some TfLite delegate by default, the flex ops preparation could fail if
|
||||||
|
// the flex delegate isn't applied first, in which case this type of exception is thrown.
|
||||||
|
// Expected failure
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -104,7 +104,15 @@ public final class NativeInterpreterWrapperTest {
|
|||||||
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(MODEL_WITH_CUSTOM_OP_PATH);
|
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(MODEL_WITH_CUSTOM_OP_PATH);
|
||||||
fail();
|
fail();
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
assertThat(e).hasMessageThat().contains("Encountered unresolved custom op: Assign");
|
assertThat(e)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("preparing tensor allocations: Encountered unresolved custom op: Assign");
|
||||||
|
} catch (IllegalArgumentException e) {
|
||||||
|
// As we could apply TfLite delegate by default, during which the preparation of this
|
||||||
|
// unresolved custom op could fail and this type of exception is thrown.
|
||||||
|
assertThat(e)
|
||||||
|
.hasMessageThat()
|
||||||
|
.containsMatch("Failed to apply .* delegate: Encountered unresolved custom op: Assign");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -201,8 +209,20 @@ public final class NativeInterpreterWrapperTest {
|
|||||||
outputs.put(0, parsedOutputs);
|
outputs.put(0, parsedOutputs);
|
||||||
wrapper.run(inputs, outputs);
|
wrapper.run(inputs, outputs);
|
||||||
long[] outputOneD = parsedOutputs[0][0][0];
|
long[] outputOneD = parsedOutputs[0][0][0];
|
||||||
long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
|
long[] expected = {
|
||||||
-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
|
-892834092L,
|
||||||
|
923423L,
|
||||||
|
2123918239018L,
|
||||||
|
-892834092L,
|
||||||
|
923423L,
|
||||||
|
2123918239018L,
|
||||||
|
-892834092L,
|
||||||
|
923423L,
|
||||||
|
2123918239018L,
|
||||||
|
-892834092L,
|
||||||
|
923423L,
|
||||||
|
2123918239018L
|
||||||
|
};
|
||||||
assertThat(outputOneD).isEqualTo(expected);
|
assertThat(outputOneD).isEqualTo(expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -222,8 +242,20 @@ public final class NativeInterpreterWrapperTest {
|
|||||||
outputs.put(0, parsedOutputs);
|
outputs.put(0, parsedOutputs);
|
||||||
wrapper.run(inputs, outputs);
|
wrapper.run(inputs, outputs);
|
||||||
byte[] outputOneD = parsedOutputs[0][0][0];
|
byte[] outputOneD = parsedOutputs[0][0][0];
|
||||||
byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
|
byte[] expected = {
|
||||||
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
|
(byte) 0xe0,
|
||||||
|
0x4f,
|
||||||
|
(byte) 0xd0,
|
||||||
|
(byte) 0xe0,
|
||||||
|
0x4f,
|
||||||
|
(byte) 0xd0,
|
||||||
|
(byte) 0xe0,
|
||||||
|
0x4f,
|
||||||
|
(byte) 0xd0,
|
||||||
|
(byte) 0xe0,
|
||||||
|
0x4f,
|
||||||
|
(byte) 0xd0
|
||||||
|
};
|
||||||
assertThat(outputOneD).isEqualTo(expected);
|
assertThat(outputOneD).isEqualTo(expected);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -242,7 +274,7 @@ public final class NativeInterpreterWrapperTest {
|
|||||||
wrapper.run(inputs, outputs);
|
wrapper.run(inputs, outputs);
|
||||||
String[] outputOneD = parsedOutputs[0][0][0];
|
String[] outputOneD = parsedOutputs[0][0][0];
|
||||||
String[] expected = {
|
String[] expected = {
|
||||||
"s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333"
|
"s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333"
|
||||||
};
|
};
|
||||||
assertThat(outputOneD).isEqualTo(expected);
|
assertThat(outputOneD).isEqualTo(expected);
|
||||||
}
|
}
|
||||||
@ -276,8 +308,8 @@ public final class NativeInterpreterWrapperTest {
|
|||||||
wrapper.run(inputs, outputs);
|
wrapper.run(inputs, outputs);
|
||||||
String[] outputOneD = parsedOutputs[0][0][0];
|
String[] outputOneD = parsedOutputs[0][0][0];
|
||||||
String[] expected = {
|
String[] expected = {
|
||||||
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e",
|
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e",
|
||||||
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e"
|
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e"
|
||||||
};
|
};
|
||||||
assertThat(outputOneD).isEqualTo(expected);
|
assertThat(outputOneD).isEqualTo(expected);
|
||||||
}
|
}
|
||||||
@ -332,8 +364,8 @@ public final class NativeInterpreterWrapperTest {
|
|||||||
wrapper.run(inputs, outputs);
|
wrapper.run(inputs, outputs);
|
||||||
byte[] outputOneD = parsedOutputs[0][0][0];
|
byte[] outputOneD = parsedOutputs[0][0][0];
|
||||||
byte[] expected = {
|
byte[] expected = {
|
||||||
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
|
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
|
||||||
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
|
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
|
||||||
};
|
};
|
||||||
assertThat(outputOneD).isEqualTo(expected);
|
assertThat(outputOneD).isEqualTo(expected);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user