Merge pull request #31699 from jdduke/cherrypicks_5UP6F

Ensure native libs are loaded when using NnApiDelegate
This commit is contained in:
Goldie Gadde 2019-08-16 14:19:53 -07:00 committed by GitHub
commit bdbaf055f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 101 additions and 4 deletions

View File

@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.lite.nnapi;
import org.tensorflow.lite.Delegate;
import org.tensorflow.lite.TensorFlowLite;
/** {@link Delegate} for NNAPI inference. */
public class NnApiDelegate implements Delegate, AutoCloseable {
@ -44,4 +45,9 @@ public class NnApiDelegate implements Delegate, AutoCloseable {
}
private static native long createDelegate();
static {
// Ensure the native TensorFlow Lite libraries are available.
TensorFlowLite.init();
}
}

View File

@ -195,6 +195,27 @@ java_test(
],
)
java_test(
name = "NnApiDelegateTest",
size = "small",
srcs = [
"src/test/java/org/tensorflow/lite/TestUtils.java",
"src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java",
],
data = [
"src/testdata/add.bin",
],
javacopts = JAVACOPTS,
tags = ["no_mac"],
test_class = "org.tensorflow.lite.nnapi.NnApiDelegateTest",
visibility = ["//visibility:private"],
deps = [
":tensorflowlitelib",
"@com_google_truth",
"@junit",
],
)
java_test(
name = "InterpreterFlexTest",
size = "small",
@ -244,6 +265,7 @@ filegroup(
srcs = [
"src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java",
"src/test/java/org/tensorflow/lite/InterpreterTest.java",
"src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java",
],
visibility = ["//visibility:public"],
)

View File

@ -22,6 +22,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.lite.nnapi.NnApiDelegate;
/**
* An internal wrapper that wraps native interpreter and controls model execution.
@ -69,9 +70,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
this.interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
this.inputTensors = new Tensor[getInputCount(interpreterHandle)];
this.outputTensors = new Tensor[getOutputCount(interpreterHandle)];
if (options.useNNAPI != null) {
setUseNNAPI(options.useNNAPI.booleanValue());
}
if (options.allowFp16PrecisionForFp32 != null) {
allowFp16PrecisionForFp32(
interpreterHandle, options.allowFp16PrecisionForFp32.booleanValue());
@ -79,6 +77,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
if (options.allowBufferHandleOutput != null) {
allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
}
if (options.useNNAPI != null && options.useNNAPI.booleanValue()) {
optionalNnApiDelegate = new NnApiDelegate();
applyDelegate(interpreterHandle, errorHandle, optionalNnApiDelegate.getNativeHandle());
}
for (Delegate delegate : options.delegates) {
applyDelegate(interpreterHandle, errorHandle, delegate.getNativeHandle());
delegates.add(delegate);
@ -112,6 +114,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
outputsIndexes = null;
isMemoryAllocated = false;
delegates.clear();
if (optionalNnApiDelegate != null) {
optionalNnApiDelegate.close();
optionalNnApiDelegate = null;
}
}
/** Sets inputs, runs model inference and returns outputs. */
@ -345,6 +351,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
// delegates for safety.
private final List<Delegate> delegates = new ArrayList<>();
// Prefer using the NnApiDelegate directly rather than the deprecated useNNNAPI() method when
// NNAPI is enabled via Interpreter.Options.
private NnApiDelegate optionalNnApiDelegate;
private static native long allocateTensors(long interpreterHandle, long errorHandle);
private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);

View File

@ -47,8 +47,10 @@ public final class TensorFlowLite {
/**
* Load the TensorFlowLite runtime C library.
*
* @hide
*/
static boolean init() {
public static boolean init() {
Throwable primaryLibException;
try {
System.loadLibrary(PRIMARY_LIBNAME);

View File

@ -0,0 +1,57 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite.nnapi;
import static com.google.common.truth.Truth.assertThat;
import java.nio.ByteBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.TestUtils;
/** Unit tests for {@link org.tensorflow.lite.nnapi.NnApiDelegate}. */
@RunWith(JUnit4.class)
public final class NnApiDelegateTest {
private static final String MODEL_PATH = "tensorflow/lite/java/src/testdata/add.bin";
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
@Test
public void testBasic() throws Exception {
try (NnApiDelegate delegate = new NnApiDelegate()) {
assertThat(delegate.getNativeHandle()).isNotEqualTo(0);
}
}
@Test
public void testInterpreterWithNnApi() throws Exception {
Interpreter.Options options = new Interpreter.Options();
try (NnApiDelegate delegate = new NnApiDelegate();
Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
float[][][][] parsedOutputs = new float[2][8][8][3];
interpreter.run(fourD, parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, 19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
}
}
}