[tf.lite] Adds a setQuantizedModelsAllowed() Java API for running quantized models with the GPU delegate
PiperOrigin-RevId: 311402449 Change-Id: I49809a004ad11c4bc9d9e5272472f3b85ea7948f
Parent: ffb7db1a81
Commit: 062cf92d06
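In practice the new option is a one-line opt-in on the delegate. A minimal usage sketch, assuming the caller already has a quantized .tflite flatbuffer in a direct ByteBuffer (the output shape of 1x1001 is illustrative, matching the MobileNet v1 model used in the tests below, not part of the API):

import java.nio.ByteBuffer;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;

final class QuantOnGpuExample {
  // `model` holds a quantized .tflite flatbuffer; `input` holds one preprocessed image.
  static byte[][] classify(ByteBuffer model, ByteBuffer input) {
    byte[][] output = new byte[1][1001];  // shape is model-specific
    Interpreter.Options options = new Interpreter.Options();
    // Resources close in reverse order, so the interpreter closes before the delegate.
    try (GpuDelegate delegate =
            new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true));
        Interpreter interpreter = new Interpreter(model, options.addDelegate(delegate))) {
      interpreter.run(input, output);
    }
    return output;
  }
}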
GpuDelegate.java
@@ -62,6 +62,18 @@ public class GpuDelegate implements Delegate, Closeable {
       return this;
     }
 
+    /**
+     * Enables running quantized models with the delegate. Defaults to false.
+     *
+     * <p>WARNING: This is an experimental API and subject to change.
+     *
+     * @param quantizedModelsAllowed When {@code true}, the GPU may run quantized models.
+     */
+    public Options setQuantizedModelsAllowed(boolean quantizedModelsAllowed) {
+      this.quantizedModelsAllowed = quantizedModelsAllowed;
+      return this;
+    }
+
     /**
      * Sets the inference preference for precision/compilation/runtime tradeoffs.
      *
@@ -74,11 +86,16 @@ public class GpuDelegate implements Delegate, Closeable {
     }
 
     boolean precisionLossAllowed = true;
+    boolean quantizedModelsAllowed = false;
     int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
   }
 
   public GpuDelegate(Options options) {
-    delegateHandle = createDelegate(options.precisionLossAllowed, options.inferencePreference);
+    delegateHandle =
+        createDelegate(
+            options.precisionLossAllowed,
+            options.quantizedModelsAllowed,
+            options.inferencePreference);
   }
 
   public GpuDelegate() {
@@ -107,7 +124,8 @@ public class GpuDelegate implements Delegate, Closeable {
     System.loadLibrary(TFLITE_GPU_LIB);
   }
 
-  private static native long createDelegate(boolean precisionLossAllowed, int preference);
+  private static native long createDelegate(
+      boolean precisionLossAllowed, boolean quantizedModelsAllowed, int preference);
 
   private static native void deleteDelegate(long delegateHandle);
 }
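Because `quantizedModelsAllowed` defaults to false, existing callers are untouched; the new setter simply chains with the options that were already there. A small sketch of both paths (`setPrecisionLossAllowed` and `INFERENCE_PREFERENCE_SUSTAINED_SPEED` are pre-existing API on this class; the factory names are illustrative):

import org.tensorflow.lite.gpu.GpuDelegate;

final class DelegateFactory {
  // Old behavior, unchanged: quantized models fall back to the CPU.
  static GpuDelegate defaultDelegate() {
    return new GpuDelegate();
  }

  // Opting in: the new setter chains with the existing ones, since all return `this`.
  static GpuDelegate quantFriendlyDelegate() {
    return new GpuDelegate(
        new GpuDelegate.Options()
            .setPrecisionLossAllowed(true)
            .setQuantizedModelsAllowed(true)
            .setInferencePreference(
                GpuDelegate.Options.INFERENCE_PREFERENCE_SUSTAINED_SPEED));
  }
}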
GpuDelegate JNI (C++)
@@ -23,7 +23,7 @@ extern "C" {
 
 JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
     JNIEnv* env, jclass clazz, jboolean precision_loss_allowed,
-    jint inference_preference) {
+    jboolean quantized_models_allowed, jint inference_preference) {
   TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
   if (precision_loss_allowed == JNI_TRUE) {
     options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
@@ -31,6 +31,10 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
         TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE;
     options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION;
   }
+  options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE;
+  if (quantized_models_allowed) {
+    options.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
+  }
   options.inference_preference = static_cast<int32_t>(inference_preference);
   return reinterpret_cast<jlong>(TfLiteGpuDelegateV2Create(&options));
 }
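Net effect of the JNI change: the Java boolean becomes a bit in the delegate's `experimental_flags` mask, so quantized-model support is enabled on `TfLiteGpuDelegateOptionsV2` only when `TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT` is set, and all other flags stay cleared.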
BUILD (Bazel)
@@ -353,6 +353,7 @@ filegroup(
 filegroup(
     name = "portable_gpu_tests",
     srcs = [
+        "src/test/java/org/tensorflow/lite/InterpreterTestHelper.java",
         "src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java",
     ],
     visibility = ["//visibility:public"],
Interpreter.java
@@ -491,6 +491,11 @@ public final class Interpreter implements AutoCloseable {
     wrapper.resetVariableTensors();
   }
 
+  int getExecutionPlanLength() {
+    checkNotClosed();
+    return wrapper.getExecutionPlanLength();
+  }
+
   /** Release resources associated with the {@code Interpreter}. */
   @Override
   public void close() {
NativeInterpreterWrapper.java
@@ -324,6 +324,11 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     return outputTensor;
   }
 
+  /** Gets the number of ops in the execution plan. */
+  int getExecutionPlanLength() {
+    return getExecutionPlanLength(interpreterHandle);
+  }
+
   private void applyDelegates(Interpreter.Options options) {
     // First apply the flex delegate if necessary. This ensures the graph is fully resolved before
     // applying other delegates.
@@ -419,6 +424,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
 
   private static native int getOutputCount(long interpreterHandle);
 
+  private static native int getExecutionPlanLength(long interpreterHandle);
+
   private static native String[] getInputNames(long interpreterHandle);
 
   private static native String[] getOutputNames(long interpreterHandle);
NativeInterpreterWrapper JNI (C++)
@@ -241,6 +241,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
   return interpreter->outputs()[output_index];
 }
 
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
+    JNIEnv* env, jclass clazz, jlong handle) {
+  tflite_api_dispatcher::Interpreter* interpreter =
+      convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return 0;
+  return static_cast<jint>(interpreter->execution_plan().size());
+}
+
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
                                                                 jclass clazz,
InterpreterTestHelper.java (new file)
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+/** Utility for interacting with Interpreter in delegate tests. */
+public abstract class InterpreterTestHelper {
+
+  /**
+   * Returns the number of nodes in the execution plan that are invoked per inference.
+   *
+   * <p>WARNING: This is an experimental API and subject to change.
+   */
+  public static int executionPlanLength(Interpreter interpreter) {
+    return interpreter.getExecutionPlanLength();
+  }
+}
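The helper exists because `Interpreter.getExecutionPlanLength()` is package-private to `org.tensorflow.lite`, while the GPU tests live in `org.tensorflow.lite.gpu`. Since a fully delegated graph collapses into a single delegate node, the plan length doubles as a cheap delegation check. A hypothetical test-side wrapper (the class and method names here are illustrative):

import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.InterpreterTestHelper;

final class DelegationChecks {
  /** True when the whole graph was replaced by a single delegate node. */
  static boolean fullyDelegated(Interpreter interpreter) {
    return InterpreterTestHelper.executionPlanLength(interpreter) == 1;
  }
}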
GpuDelegateTest.java
@@ -18,12 +18,17 @@ package org.tensorflow.lite.gpu;
 import static com.google.common.truth.Truth.assertThat;
 
 import java.nio.ByteBuffer;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.PriorityQueue;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 import org.tensorflow.lite.Interpreter;
+import org.tensorflow.lite.InterpreterTestHelper;
 import org.tensorflow.lite.TestUtils;
 
 /** Unit tests for {@link org.tensorflow.lite.gpu.GpuDelegate}. */
@@ -32,6 +37,9 @@ public final class GpuDelegateTest {
 
   private static final String MODEL_PATH = "tensorflow/lite/testdata/multi_add.bin";
   private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
+  private static final ByteBuffer MOBILENET_QUANTIZED_MODEL_BUFFER =
+      TestUtils.getTestFileAsBuffer(
+          "third_party/tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
 
   @Test
   public void testBasic() throws Exception {
@@ -41,7 +49,7 @@ public final class GpuDelegateTest {
   }
 
   @Test
-  public void testInterpreterWithGpu() throws Exception {
+  public void testInterpreterWithGpu_FloatModel() throws Exception {
     Interpreter.Options options = new Interpreter.Options();
     try (GpuDelegate delegate = new GpuDelegate();
         Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
@@ -60,4 +68,79 @@ public final class GpuDelegateTest {
       assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder();
     }
   }
+
+  @Test
+  public void testInterpreterWithGpu_QuantModelRunWithDelegate() throws Exception {
+    ByteBuffer img =
+        TestUtils.getTestImageAsByteBuffer(
+            "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
+
+    Interpreter.Options options = new Interpreter.Options();
+    try (GpuDelegate delegate =
+            new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true));
+        Interpreter interpreter =
+            new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
+      byte[][] output = new byte[1][1001];
+      interpreter.run(img, output);
+      // Should be only 1 node (Delegate) in the execution plan.
+      assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(1);
+      assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+      assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
+      // 653 == "military uniform"
+      assertThat(getTopKLabels(output, 3)).contains(653);
+    }
+  }
+
+  @Test
+  public void testInterpreterWithGpu_QuantModelRunOnCPU() throws Exception {
+    ByteBuffer img =
+        TestUtils.getTestImageAsByteBuffer(
+            "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
+
+    Interpreter.Options options = new Interpreter.Options();
+    try (GpuDelegate delegate = new GpuDelegate();
+        Interpreter interpreter =
+            new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
+      byte[][] output = new byte[1][1001];
+      interpreter.run(img, output);
+      // Original execution plan remains since default behavior doesn't allow quantized models.
+      assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(31);
+      assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+      assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
+      // 653 == "military uniform"
+      assertThat(getTopKLabels(output, 3)).contains(653);
+    }
+  }
+
+  private static ArrayList<Integer> getTopKLabels(byte[][] byteLabels, int k) {
+    float[][] labels = new float[1][1001];
+    for (int i = 0; i < byteLabels[0].length; ++i) {
+      labels[0][i] = (byteLabels[0][i] & 0xff) / 255.0f;
+    }
+    return getTopKLabels(labels, k);
+  }
+
+  private static ArrayList<Integer> getTopKLabels(float[][] labels, int k) {
+    PriorityQueue<Map.Entry<Integer, Float>> pq =
+        new PriorityQueue<>(
+            k,
+            new Comparator<Map.Entry<Integer, Float>>() {
+              @Override
+              public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) {
+                // Intentionally reversed to put high confidence at the head of the queue.
+                return o1.getValue().compareTo(o2.getValue()) * -1;
+              }
+            });
+
+    for (int i = 0; i < labels[0].length; ++i) {
+      pq.add(new AbstractMap.SimpleEntry<>(i, labels[0][i]));
+    }
+
+    final ArrayList<Integer> topKLabels = new ArrayList<>();
+    int topKLabelsSize = Math.min(pq.size(), k);
+    for (int i = 0; i < topKLabelsSize; ++i) {
+      topKLabels.add(pq.poll().getKey());
+    }
+    return topKLabels;
+  }
 }
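One detail in `getTopKLabels(byte[][], int)` worth spelling out: Java bytes are signed, while the quantized model's output tensor holds uint8 scores, so the `& 0xff` mask is what keeps high-confidence scores from going negative before the 1/255 scaling. A standalone illustration (the class name and sample value are illustrative):

public final class UnsignedByteDemo {
  public static void main(String[] args) {
    byte raw = (byte) 0xFA;                     // uint8 score 250 as stored in the tensor
    System.out.println(raw);                    // -6: sign extension would corrupt the ranking
    System.out.println(raw & 0xff);             // 250: the mask recovers the unsigned value
    System.out.println((raw & 0xff) / 255.0f);  // ~0.98: normalized confidence
  }
}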