Merge pull request #39527 from freedomtan:java_binding_for_allow_fp16_in_nnapi

PiperOrigin-RevId: 314196562
Change-Id: I67ed2b89e99d64523618310b17f0a0cc5754c9e7
This commit is contained in:
TensorFlower Gardener 2020-06-01 14:04:23 -07:00
commit a2ac929f45
5 changed files with 47 additions and 4 deletions

View File

@ -118,12 +118,24 @@ public class NnApiDelegate implements Delegate, AutoCloseable {
return this; return this;
} }
/**
 * Sets whether NNAPI is allowed to run fp32 computations using fp16. See
 * https://source.android.com/devices/neural-networks#android-9
 *
 * <p>Only effective on Android 9 (API level 28) and above.
 *
 * @param enable {@code true} to allow fp32 computations to be run in fp16, {@code false} to
 *     disallow it
 * @return this {@code Options} instance, so that setter calls can be chained
 */
public Options setAllowFp16(boolean enable) {
  // The field is a nullable Boolean; box explicitly so "unset" (null) stays distinguishable.
  allowFp16 = Boolean.valueOf(enable);
  return this;
}
private int executionPreference = EXECUTION_PREFERENCE_UNDEFINED; private int executionPreference = EXECUTION_PREFERENCE_UNDEFINED;
private String acceleratorName = null; private String acceleratorName = null;
private String cacheDir = null; private String cacheDir = null;
private String modelToken = null; private String modelToken = null;
private Integer maxDelegatedPartitions = null; private Integer maxDelegatedPartitions = null;
private Boolean useNnapiCpu = null; private Boolean useNnapiCpu = null;
private Boolean allowFp16 = null;
} }
public NnApiDelegate(Options options) { public NnApiDelegate(Options options) {
@ -139,7 +151,8 @@ public class NnApiDelegate implements Delegate, AutoCloseable {
/*overrideDisallowCpu=*/ options.useNnapiCpu != null, /*overrideDisallowCpu=*/ options.useNnapiCpu != null,
/*disallowCpuValue=*/ options.useNnapiCpu != null /*disallowCpuValue=*/ options.useNnapiCpu != null
? !options.useNnapiCpu.booleanValue() ? !options.useNnapiCpu.booleanValue()
: false); : false,
options.allowFp16 != null ? options.allowFp16 : false);
} }
public NnApiDelegate() { public NnApiDelegate() {
@ -204,7 +217,8 @@ public class NnApiDelegate implements Delegate, AutoCloseable {
String modelToken, String modelToken,
int maxDelegatedPartitions, int maxDelegatedPartitions,
boolean overrideDisallowCpu, boolean overrideDisallowCpu,
boolean disallowCpuValue); boolean disallowCpuValue,
boolean allowFp16);
private static native void deleteDelegate(long delegateHandle); private static native void deleteDelegate(long delegateHandle);

View File

@ -27,7 +27,8 @@ JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_nnapi_NnApiDelegate_createDelegate( Java_org_tensorflow_lite_nnapi_NnApiDelegate_createDelegate(
JNIEnv* env, jclass clazz, jint preference, jstring accelerator_name, JNIEnv* env, jclass clazz, jint preference, jstring accelerator_name,
jstring cache_dir, jstring model_token, jint max_delegated_partitions, jstring cache_dir, jstring model_token, jint max_delegated_partitions,
jboolean override_disallow_cpu, jboolean disallow_cpu_value) { jboolean override_disallow_cpu, jboolean disallow_cpu_value,
jboolean allow_fp16) {
StatefulNnApiDelegate::Options options = StatefulNnApiDelegate::Options(); StatefulNnApiDelegate::Options options = StatefulNnApiDelegate::Options();
options.execution_preference = options.execution_preference =
(StatefulNnApiDelegate::Options::ExecutionPreference)preference; (StatefulNnApiDelegate::Options::ExecutionPreference)preference;
@ -49,6 +50,10 @@ Java_org_tensorflow_lite_nnapi_NnApiDelegate_createDelegate(
options.disallow_nnapi_cpu = disallow_cpu_value; options.disallow_nnapi_cpu = disallow_cpu_value;
} }
if (allow_fp16) {
options.allow_fp16 = allow_fp16;
}
auto delegate = new StatefulNnApiDelegate(options); auto delegate = new StatefulNnApiDelegate(options);
if (options.accelerator_name) { if (options.accelerator_name) {

View File

@ -102,8 +102,10 @@ public final class Interpreter implements AutoCloseable {
* Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to false * Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to false
* (disallow). * (disallow).
* *
* <p>WARNING: This is an experimental API and subject to change. * @deprecated Prefer using {@link
* org.tensorflow.lite.nnapi.NnApiDelegate.Options#setAllowFp16(boolean enable)}.
*/ */
@Deprecated
public Options setAllowFp16PrecisionForFp32(boolean allow) { public Options setAllowFp16PrecisionForFp32(boolean allow) {
this.allowFp16PrecisionForFp32 = allow; this.allowFp16PrecisionForFp32 = allow;
return this; return this;

View File

@ -65,6 +65,7 @@ public final class InterpreterTest {
} }
@Test @Test
@SuppressWarnings("deprecation")
public void testInterpreterWithOptions() throws Exception { public void testInterpreterWithOptions() throws Exception {
Interpreter interpreter = Interpreter interpreter =
new Interpreter( new Interpreter(
@ -390,6 +391,7 @@ public final class InterpreterTest {
} }
@Test @Test
@SuppressWarnings("deprecation")
public void testTurnOnNNAPI() throws Exception { public void testTurnOnNNAPI() throws Exception {
Interpreter interpreter = Interpreter interpreter =
new Interpreter( new Interpreter(

View File

@ -56,6 +56,26 @@ public final class NnApiDelegateTest {
} }
} }
/**
 * Runs the test model through an interpreter that uses an NNAPI delegate with fp16 allowed, and
 * checks the first output row against the expected values within a tolerance of 0.1.
 */
@Test
public void testInterpreterWithNnApiAllowFp16() throws Exception {
  Interpreter.Options interpreterOptions = new Interpreter.Options();
  NnApiDelegate.Options delegateOptions = new NnApiDelegate.Options().setAllowFp16(true);
  try (NnApiDelegate nnApiDelegate = new NnApiDelegate(delegateOptions);
      Interpreter tflite =
          new Interpreter(MODEL_BUFFER, interpreterOptions.addDelegate(nnApiDelegate))) {
    // Build a [2][8][8][3] input in which every innermost row is the same three values.
    float[] row = {1.23f, 6.54f, 7.81f};
    float[][] plane = {row, row, row, row, row, row, row, row};
    float[][][] volume = {plane, plane, plane, plane, plane, plane, plane, plane};
    float[][][][] input = {volume, volume};
    float[][][][] output = new float[2][8][8][3];

    tflite.run(input, output);

    // Each expected value is three times the corresponding input value; the loose tolerance
    // presumably leaves room for reduced fp16 precision (confirm against the model).
    float[] expected = {3.69f, 19.62f, 23.43f};
    assertThat(output[0][0][0]).usingTolerance(0.1f).containsExactly(expected).inOrder();
  }
}
@Test @Test
public void testGetNnApiErrnoReturnsZeroIfNoNnapiCallFailed() throws Exception { public void testGetNnApiErrnoReturnsZeroIfNoNnapiCallFailed() throws Exception {
Interpreter.Options options = new Interpreter.Options(); Interpreter.Options options = new Interpreter.Options();