Merge pull request #31699 from jdduke/cherrypicks_5UP6F
Ensure native libs are loaded when using NnApiDelegate
This commit is contained in:
commit
bdbaf055f1
@ -16,6 +16,7 @@ limitations under the License.
|
||||
package org.tensorflow.lite.nnapi;
|
||||
|
||||
import org.tensorflow.lite.Delegate;
|
||||
import org.tensorflow.lite.TensorFlowLite;
|
||||
|
||||
/** {@link Delegate} for NNAPI inference. */
|
||||
public class NnApiDelegate implements Delegate, AutoCloseable {
|
||||
@ -44,4 +45,9 @@ public class NnApiDelegate implements Delegate, AutoCloseable {
|
||||
}
|
||||
|
||||
private static native long createDelegate();
|
||||
|
||||
static {
|
||||
// Ensure the native TensorFlow Lite libraries are available.
|
||||
TensorFlowLite.init();
|
||||
}
|
||||
}
|
||||
|
@ -195,6 +195,27 @@ java_test(
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "NnApiDelegateTest",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"src/test/java/org/tensorflow/lite/TestUtils.java",
|
||||
"src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java",
|
||||
],
|
||||
data = [
|
||||
"src/testdata/add.bin",
|
||||
],
|
||||
javacopts = JAVACOPTS,
|
||||
tags = ["no_mac"],
|
||||
test_class = "org.tensorflow.lite.nnapi.NnApiDelegateTest",
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":tensorflowlitelib",
|
||||
"@com_google_truth",
|
||||
"@junit",
|
||||
],
|
||||
)
|
||||
|
||||
java_test(
|
||||
name = "InterpreterFlexTest",
|
||||
size = "small",
|
||||
@ -244,6 +265,7 @@ filegroup(
|
||||
srcs = [
|
||||
"src/test/java/org/tensorflow/lite/InterpreterMobileNetTest.java",
|
||||
"src/test/java/org/tensorflow/lite/InterpreterTest.java",
|
||||
"src/test/java/org/tensorflow/lite/nnapi/NnApiDelegateTest.java",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
@ -22,6 +22,7 @@ import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.tensorflow.lite.nnapi.NnApiDelegate;
|
||||
|
||||
/**
|
||||
* An internal wrapper that wraps native interpreter and controls model execution.
|
||||
@ -69,9 +70,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
this.interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
|
||||
this.inputTensors = new Tensor[getInputCount(interpreterHandle)];
|
||||
this.outputTensors = new Tensor[getOutputCount(interpreterHandle)];
|
||||
if (options.useNNAPI != null) {
|
||||
setUseNNAPI(options.useNNAPI.booleanValue());
|
||||
}
|
||||
if (options.allowFp16PrecisionForFp32 != null) {
|
||||
allowFp16PrecisionForFp32(
|
||||
interpreterHandle, options.allowFp16PrecisionForFp32.booleanValue());
|
||||
@ -79,6 +77,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
if (options.allowBufferHandleOutput != null) {
|
||||
allowBufferHandleOutput(interpreterHandle, options.allowBufferHandleOutput.booleanValue());
|
||||
}
|
||||
if (options.useNNAPI != null && options.useNNAPI.booleanValue()) {
|
||||
optionalNnApiDelegate = new NnApiDelegate();
|
||||
applyDelegate(interpreterHandle, errorHandle, optionalNnApiDelegate.getNativeHandle());
|
||||
}
|
||||
for (Delegate delegate : options.delegates) {
|
||||
applyDelegate(interpreterHandle, errorHandle, delegate.getNativeHandle());
|
||||
delegates.add(delegate);
|
||||
@ -112,6 +114,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
outputsIndexes = null;
|
||||
isMemoryAllocated = false;
|
||||
delegates.clear();
|
||||
if (optionalNnApiDelegate != null) {
|
||||
optionalNnApiDelegate.close();
|
||||
optionalNnApiDelegate = null;
|
||||
}
|
||||
}
|
||||
|
||||
/** Sets inputs, runs model inference and returns outputs. */
|
||||
@ -345,6 +351,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
|
||||
// delegates for safety.
|
||||
private final List<Delegate> delegates = new ArrayList<>();
|
||||
|
||||
// Prefer using the NnApiDelegate directly rather than the deprecated useNNNAPI() method when
|
||||
// NNAPI is enabled via Interpreter.Options.
|
||||
private NnApiDelegate optionalNnApiDelegate;
|
||||
|
||||
private static native long allocateTensors(long interpreterHandle, long errorHandle);
|
||||
|
||||
private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);
|
||||
|
@ -47,8 +47,10 @@ public final class TensorFlowLite {
|
||||
|
||||
/**
|
||||
* Load the TensorFlowLite runtime C library.
|
||||
*
|
||||
* @hide
|
||||
*/
|
||||
static boolean init() {
|
||||
public static boolean init() {
|
||||
Throwable primaryLibException;
|
||||
try {
|
||||
System.loadLibrary(PRIMARY_LIBNAME);
|
||||
|
@ -0,0 +1,57 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.lite.nnapi;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.JUnit4;
|
||||
import org.tensorflow.lite.Interpreter;
|
||||
import org.tensorflow.lite.TestUtils;
|
||||
|
||||
/** Unit tests for {@link org.tensorflow.lite.nnapi.NnApiDelegate}. */
|
||||
@RunWith(JUnit4.class)
|
||||
public final class NnApiDelegateTest {
|
||||
|
||||
private static final String MODEL_PATH = "tensorflow/lite/java/src/testdata/add.bin";
|
||||
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
|
||||
|
||||
@Test
|
||||
public void testBasic() throws Exception {
|
||||
try (NnApiDelegate delegate = new NnApiDelegate()) {
|
||||
assertThat(delegate.getNativeHandle()).isNotEqualTo(0);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInterpreterWithNnApi() throws Exception {
|
||||
Interpreter.Options options = new Interpreter.Options();
|
||||
try (NnApiDelegate delegate = new NnApiDelegate();
|
||||
Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
|
||||
float[] oneD = {1.23f, 6.54f, 7.81f};
|
||||
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
|
||||
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
|
||||
float[][][][] fourD = {threeD, threeD};
|
||||
float[][][][] parsedOutputs = new float[2][8][8][3];
|
||||
interpreter.run(fourD, parsedOutputs);
|
||||
float[] outputOneD = parsedOutputs[0][0][0];
|
||||
float[] expected = {3.69f, 19.62f, 23.43f};
|
||||
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user