Prepare to make xnnpack delegate by default in TFLite Java binding.
PiperOrigin-RevId: 343778493 Change-Id: I4e5fa33d59a4989b7b0734b17a93567af9b90ce7
This commit is contained in:
parent
cdc17e37e2
commit
e8474b8065
@ -178,6 +178,11 @@ public final class Interpreter implements AutoCloseable {
|
||||
Boolean allowFp16PrecisionForFp32;
|
||||
Boolean allowBufferHandleOutput;
|
||||
Boolean allowCancellation;
|
||||
|
||||
// TODO(b/171856982): update the comment when applying XNNPACK delegate by default is
|
||||
// enabled for C++ TfLite library on Android platform.
|
||||
// Note: the initial "null" value indicates default behavior which may mean XNNPACK
|
||||
// delegate will be applied by default.
|
||||
Boolean useXNNPACK;
|
||||
final List<Delegate> delegates = new ArrayList<>();
|
||||
}
|
||||
|
@ -83,9 +83,17 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
|
||||
}
|
||||
applyDelegates(options);
|
||||
|
||||
// Simply use "-1" to represent the default mode.
|
||||
int applyXNNPACKMode = -1;
|
||||
if (options.useXNNPACK != null) {
|
||||
useXNNPACK(
|
||||
interpreterHandle, errorHandle, options.useXNNPACK.booleanValue(), options.numThreads);
|
||||
applyXNNPACKMode = options.useXNNPACK.booleanValue() ? 1 : 0;
|
||||
}
|
||||
|
||||
// TODO(b/171856982): uncomment the following when applying XNNPACK delegate by default is
|
||||
// enabled for C++ TfLite library on Android platform.
|
||||
if (applyXNNPACKMode == 1 /*|| applyXNNPACKMode == -1*/) {
|
||||
useXNNPACK(interpreterHandle, errorHandle, applyXNNPACKMode, options.numThreads);
|
||||
}
|
||||
allocateTensors(interpreterHandle, errorHandle);
|
||||
this.isMemoryAllocated = true;
|
||||
@ -459,7 +467,7 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
private static native void allowBufferHandleOutput(long interpreterHandle, boolean allow);
|
||||
|
||||
private static native void useXNNPACK(
|
||||
long interpreterHandle, long errorHandle, boolean state, int numThreads);
|
||||
long interpreterHandle, long errorHandle, int state, int numThreads);
|
||||
|
||||
private static native long createErrorReporter(int size);
|
||||
|
||||
|
@ -21,12 +21,14 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
|
||||
// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in
|
||||
// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the
|
||||
// builtin ops. For smaller binary sizes users should avoid linking this in, and
|
||||
// should provide a custom make CreateOpResolver() instead.
|
||||
// the tflite namespace. This one instantiates a
|
||||
// BuiltinOpResolverWithoutDefaultDelegates, with all the builtin ops but
|
||||
// without applying any TfLite delegates by default (like the XNNPACK delegate).
|
||||
// For smaller binary sizes users should avoid linking this in, and should
|
||||
// provide a custom make CreateOpResolver() instead.
|
||||
std::unique_ptr<OpResolver> CreateOpResolver() { // NOLINT
|
||||
return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
|
||||
new tflite::ops::builtin::BuiltinOpResolver());
|
||||
new tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -33,9 +33,10 @@ void ThrowException(JNIEnv* env, const char* clazz, const char* fmt, ...);
|
||||
class BufferErrorReporter : public ErrorReporter {
|
||||
public:
|
||||
BufferErrorReporter(JNIEnv* env, int limit);
|
||||
virtual ~BufferErrorReporter();
|
||||
~BufferErrorReporter() override;
|
||||
int Report(const char* format, va_list args) override;
|
||||
const char* CachedErrorMessage();
|
||||
using ErrorReporter::Report;
|
||||
|
||||
private:
|
||||
char* buffer_;
|
||||
|
@ -319,10 +319,10 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allowBufferHandleOutput(
|
||||
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
|
||||
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jboolean state,
|
||||
JNIEnv* env, jclass clazz, jlong handle, jlong error_handle, jint state,
|
||||
jint num_threads) {
|
||||
// If not using xnnpack, simply don't apply the delegate.
|
||||
if (!state) {
|
||||
if (state == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -369,6 +369,13 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useXNNPACK(
|
||||
"Internal error: Failed to apply XNNPACK delegate: %s",
|
||||
error_reporter->CachedErrorMessage());
|
||||
}
|
||||
} else if (state == -1) {
|
||||
// Instead of throwing an exception, we tolerate the missing of such
|
||||
// dependencies because we try to apply XNNPACK delegate by default.
|
||||
TF_LITE_REPORT_ERROR(
|
||||
error_reporter,
|
||||
"WARNING: Missing necessary XNNPACK delegate dependencies to apply it "
|
||||
"by default.\n");
|
||||
} else {
|
||||
ThrowException(env, tflite::jni::kIllegalArgumentException,
|
||||
"Failed to load XNNPACK delegate from current runtime. "
|
||||
|
@ -484,6 +484,10 @@ public final class InterpreterTest {
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
// Expected failure.
|
||||
} catch (IllegalArgumentException e) {
|
||||
// As we could apply some TfLite delegate by default, the flex ops preparation could fail if
|
||||
// the flex delegate isn't applied first, in which this type of exception is thrown.
|
||||
// Expected failure
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,7 +104,15 @@ public final class NativeInterpreterWrapperTest {
|
||||
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(MODEL_WITH_CUSTOM_OP_PATH);
|
||||
fail();
|
||||
} catch (IllegalStateException e) {
|
||||
assertThat(e).hasMessageThat().contains("Encountered unresolved custom op: Assign");
|
||||
assertThat(e)
|
||||
.hasMessageThat()
|
||||
.contains("preparing tensor allocations: Encountered unresolved custom op: Assign");
|
||||
} catch (IllegalArgumentException e) {
|
||||
// As we could apply TfLite delegate by default, during which the prepration of this
|
||||
// unresolved custom op could fail and this type of exception is thrown.
|
||||
assertThat(e)
|
||||
.hasMessageThat()
|
||||
.containsMatch("Failed to apply .* delegate: Encountered unresolved custom op: Assign");
|
||||
}
|
||||
}
|
||||
|
||||
@ -201,8 +209,20 @@ public final class NativeInterpreterWrapperTest {
|
||||
outputs.put(0, parsedOutputs);
|
||||
wrapper.run(inputs, outputs);
|
||||
long[] outputOneD = parsedOutputs[0][0][0];
|
||||
long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
|
||||
-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
|
||||
long[] expected = {
|
||||
-892834092L,
|
||||
923423L,
|
||||
2123918239018L,
|
||||
-892834092L,
|
||||
923423L,
|
||||
2123918239018L,
|
||||
-892834092L,
|
||||
923423L,
|
||||
2123918239018L,
|
||||
-892834092L,
|
||||
923423L,
|
||||
2123918239018L
|
||||
};
|
||||
assertThat(outputOneD).isEqualTo(expected);
|
||||
}
|
||||
}
|
||||
@ -222,8 +242,20 @@ public final class NativeInterpreterWrapperTest {
|
||||
outputs.put(0, parsedOutputs);
|
||||
wrapper.run(inputs, outputs);
|
||||
byte[] outputOneD = parsedOutputs[0][0][0];
|
||||
byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
|
||||
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
|
||||
byte[] expected = {
|
||||
(byte) 0xe0,
|
||||
0x4f,
|
||||
(byte) 0xd0,
|
||||
(byte) 0xe0,
|
||||
0x4f,
|
||||
(byte) 0xd0,
|
||||
(byte) 0xe0,
|
||||
0x4f,
|
||||
(byte) 0xd0,
|
||||
(byte) 0xe0,
|
||||
0x4f,
|
||||
(byte) 0xd0
|
||||
};
|
||||
assertThat(outputOneD).isEqualTo(expected);
|
||||
}
|
||||
}
|
||||
@ -242,7 +274,7 @@ public final class NativeInterpreterWrapperTest {
|
||||
wrapper.run(inputs, outputs);
|
||||
String[] outputOneD = parsedOutputs[0][0][0];
|
||||
String[] expected = {
|
||||
"s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333"
|
||||
"s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333", "s1", "s22", "s333"
|
||||
};
|
||||
assertThat(outputOneD).isEqualTo(expected);
|
||||
}
|
||||
@ -276,8 +308,8 @@ public final class NativeInterpreterWrapperTest {
|
||||
wrapper.run(inputs, outputs);
|
||||
String[] outputOneD = parsedOutputs[0][0][0];
|
||||
String[] expected = {
|
||||
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e",
|
||||
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e"
|
||||
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e",
|
||||
"\uD800\uDC01", "s22", "\ud841\udf0e", "\uD800\uDC01", "s22", "\ud841\udf0e"
|
||||
};
|
||||
assertThat(outputOneD).isEqualTo(expected);
|
||||
}
|
||||
@ -332,8 +364,8 @@ public final class NativeInterpreterWrapperTest {
|
||||
wrapper.run(inputs, outputs);
|
||||
byte[] outputOneD = parsedOutputs[0][0][0];
|
||||
byte[] expected = {
|
||||
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
|
||||
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
|
||||
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
|
||||
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0
|
||||
};
|
||||
assertThat(outputOneD).isEqualTo(expected);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user