[tf.lite] Adds a setQuantizedModelsAllowed() Java API for running quantized models with the GPU delegate

PiperOrigin-RevId: 311402449
Change-Id: I49809a004ad11c4bc9d9e5272472f3b85ea7948f
Sachin Joglekar 2020-05-13 14:17:05 -07:00 committed by TensorFlower Gardener
parent ffb7db1a81
commit 062cf92d06
8 changed files with 160 additions and 4 deletions

View File

@@ -62,6 +62,18 @@ public class GpuDelegate implements Delegate, Closeable {
return this;
}
/**
* Enables running quantized models with the delegate. Defaults to false.
*
* <p>WARNING: This is an experimental API and subject to change.
*
* @param quantizedModelsAllowed When {@code true}, the GPU may run quantized models.
*/
public Options setQuantizedModelsAllowed(boolean quantizedModelsAllowed) {
this.quantizedModelsAllowed = quantizedModelsAllowed;
return this;
}
/**
* Sets the inference preference for precision/compilation/runtime tradeoffs.
*
@@ -74,11 +86,16 @@ public class GpuDelegate implements Delegate, Closeable {
}
boolean precisionLossAllowed = true;
boolean quantizedModelsAllowed = false;
int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
}
public GpuDelegate(Options options) {
delegateHandle = createDelegate(options.precisionLossAllowed, options.inferencePreference);
delegateHandle =
createDelegate(
options.precisionLossAllowed,
options.quantizedModelsAllowed,
options.inferencePreference);
}
public GpuDelegate() {
@@ -107,7 +124,8 @@ public class GpuDelegate implements Delegate, Closeable {
System.loadLibrary(TFLITE_GPU_LIB);
}
private static native long createDelegate(boolean precisionLossAllowed, int preference);
private static native long createDelegate(
boolean precisionLossAllowed, boolean quantizedModelsAllowed, int preference);
private static native void deleteDelegate(long delegateHandle);
}
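Not part of the diff: a minimal usage sketch of how application code could opt a quantized model into GPU execution with the new option. The buffer names and the 1001-class output shape are illustrative assumptions, borrowed from the Mobilenet quantized test further below.

import java.nio.ByteBuffer;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;

class QuantGpuExample {
  // Runs a quantized classifier through the GPU delegate (the experimental path added here).
  static byte[][] classifyOnGpu(ByteBuffer quantizedModel, ByteBuffer inputImage) {
    try (GpuDelegate delegate =
            new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true));
        Interpreter interpreter =
            new Interpreter(quantizedModel, new Interpreter.Options().addDelegate(delegate))) {
      byte[][] output = new byte[1][1001];  // uint8 scores for 1001 classes (assumed model)
      interpreter.run(inputImage, output);
      return output;
    }
  }
}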

View File

@@ -23,7 +23,7 @@ extern "C" {
JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
JNIEnv* env, jclass clazz, jboolean precision_loss_allowed,
jint inference_preference) {
jboolean quantized_models_allowed, jint inference_preference) {
TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
if (precision_loss_allowed == JNI_TRUE) {
options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
@@ -31,6 +31,10 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE;
options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION;
}
options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE;
if (quantized_models_allowed) {
options.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
}
options.inference_preference = static_cast<int32_t>(inference_preference);
return reinterpret_cast<jlong>(TfLiteGpuDelegateV2Create(&options));
}

View File

@@ -353,6 +353,7 @@ filegroup(
filegroup(
name = "portable_gpu_tests",
srcs = [
"src/test/java/org/tensorflow/lite/InterpreterTestHelper.java",
"src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java",
],
visibility = ["//visibility:public"],

View File

@@ -491,6 +491,11 @@ public final class Interpreter implements AutoCloseable {
wrapper.resetVariableTensors();
}
int getExecutionPlanLength() {
checkNotClosed();
return wrapper.getExecutionPlanLength();
}
/** Release resources associated with the {@code Interpreter}. */
@Override
public void close() {

View File

@@ -324,6 +324,11 @@ final class NativeInterpreterWrapper implements AutoCloseable {
return outputTensor;
}
/** Gets the number of ops in the execution plan. */
int getExecutionPlanLength() {
return getExecutionPlanLength(interpreterHandle);
}
private void applyDelegates(Interpreter.Options options) {
// First apply the flex delegate if necessary. This ensures the graph is fully resolved before
// applying other delegates.
@@ -419,6 +424,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native int getOutputCount(long interpreterHandle);
private static native int getExecutionPlanLength(long interpreterHandle);
private static native String[] getInputNames(long interpreterHandle);
private static native String[] getOutputNames(long interpreterHandle);

View File

@@ -241,6 +241,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
return interpreter->outputs()[output_index];
}
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
JNIEnv* env, jclass clazz, jlong handle) {
tflite_api_dispatcher::Interpreter* interpreter =
convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return 0;
return static_cast<jint>(interpreter->execution_plan().size());
}
JNIEXPORT jint JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
jclass clazz,

View File

@@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.lite;
/** Utility for interacting with Interpreter in delegate tests. */
public abstract class InterpreterTestHelper {
/**
* Returns the number of nodes in the execution plan that are invoked per inference.
*
* <p>WARNING: This is an experimental API and subject to change.
*/
public static int executionPlanLength(Interpreter interpreter) {
return interpreter.getExecutionPlanLength();
}
}
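Not part of the diff: a sketch of how a delegate test could use this helper to confirm delegation without hard-coding node counts, by comparing the execution plan length with and without the delegate applied (the GpuDelegateTest below asserts the fixed counts 31 and 1 instead). modelBuffer is a placeholder for a quantized model loaded into a ByteBuffer.

// After successful delegation, the delegated ops are replaced by a single delegate
// kernel node, so the plan should shrink relative to the CPU-only interpreter.
int cpuPlanLength;
try (Interpreter cpuInterpreter = new Interpreter(modelBuffer)) {
  cpuPlanLength = InterpreterTestHelper.executionPlanLength(cpuInterpreter);
}
try (GpuDelegate delegate =
        new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true));
    Interpreter gpuInterpreter =
        new Interpreter(modelBuffer, new Interpreter.Options().addDelegate(delegate))) {
  assertThat(InterpreterTestHelper.executionPlanLength(gpuInterpreter)).isLessThan(cpuPlanLength);
}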

View File

@@ -18,12 +18,17 @@ package org.tensorflow.lite.gpu;
import static com.google.common.truth.Truth.assertThat;
import java.nio.ByteBuffer;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.InterpreterTestHelper;
import org.tensorflow.lite.TestUtils;
/** Unit tests for {@link org.tensorflow.lite.gpu.GpuDelegate}. */
@@ -32,6 +37,9 @@ public final class GpuDelegateTest {
private static final String MODEL_PATH = "tensorflow/lite/testdata/multi_add.bin";
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
private static final ByteBuffer MOBILENET_QUANTIZED_MODEL_BUFFER =
TestUtils.getTestFileAsBuffer(
"third_party/tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
@Test
public void testBasic() throws Exception {
@@ -41,7 +49,7 @@ public final class GpuDelegateTest {
}
@Test
public void testInterpreterWithGpu() throws Exception {
public void testInterpreterWithGpu_FloatModel() throws Exception {
Interpreter.Options options = new Interpreter.Options();
try (GpuDelegate delegate = new GpuDelegate();
Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
@@ -60,4 +68,79 @@ public final class GpuDelegateTest {
assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder();
}
}
@Test
public void testInterpreterWithGpu_QuantModelRunWithDelegate() throws Exception {
ByteBuffer img =
TestUtils.getTestImageAsByteBuffer(
"tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
Interpreter.Options options = new Interpreter.Options();
try (GpuDelegate delegate =
new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true));
Interpreter interpreter =
new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
byte[][] output = new byte[1][1001];
interpreter.run(img, output);
// Should be only 1 node (Delegate) in the execution plan.
assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(1);
assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
// 653 == "military uniform"
assertThat(getTopKLabels(output, 3)).contains(653);
}
}
@Test
public void testInterpreterWithGpu_QuantModelRunOnCPU() throws Exception {
ByteBuffer img =
TestUtils.getTestImageAsByteBuffer(
"tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
Interpreter.Options options = new Interpreter.Options();
try (GpuDelegate delegate = new GpuDelegate();
Interpreter interpreter =
new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
byte[][] output = new byte[1][1001];
interpreter.run(img, output);
// Original execution plan remains since default behavior doesn't allow quantized models.
assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(31);
assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
// 653 == "military uniform"
assertThat(getTopKLabels(output, 3)).contains(653);
}
}
private static ArrayList<Integer> getTopKLabels(byte[][] byteLabels, int k) {
float[][] labels = new float[1][1001];
for (int i = 0; i < byteLabels[0].length; ++i) {
labels[0][i] = (byteLabels[0][i] & 0xff) / 255.0f;
}
return getTopKLabels(labels, k);
}
private static ArrayList<Integer> getTopKLabels(float[][] labels, int k) {
PriorityQueue<Map.Entry<Integer, Float>> pq =
new PriorityQueue<>(
k,
new Comparator<Map.Entry<Integer, Float>>() {
@Override
public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) {
// Intentionally reversed to put high confidence at the head of the queue.
return o1.getValue().compareTo(o2.getValue()) * -1;
}
});
for (int i = 0; i < labels[0].length; ++i) {
pq.add(new AbstractMap.SimpleEntry<>(i, labels[0][i]));
}
final ArrayList<Integer> topKLabels = new ArrayList<>();
int topKLabelsSize = Math.min(pq.size(), k);
for (int i = 0; i < topKLabelsSize; ++i) {
topKLabels.add(pq.poll().getKey());
}
return topKLabels;
}
}