[tf.lite] Adds a setQuantizedModelsAllowed() Java API for running quantized models with the GPU delegate

PiperOrigin-RevId: 311402449
Change-Id: I49809a004ad11c4bc9d9e5272472f3b85ea7948f

commit 062cf92d06 (parent ffb7db1a81)
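For reference, a minimal usage sketch of the new API (modelBuffer, input, and output are placeholder names for a quantized .tflite model and its I/O tensors, not part of this change):

    // Opt in to running quantized models on the GPU; the flag defaults to false.
    try (GpuDelegate delegate =
            new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true));
        Interpreter interpreter =
            new Interpreter(modelBuffer, new Interpreter.Options().addDelegate(delegate))) {
      interpreter.run(input, output);
    }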
GpuDelegate.java
@@ -62,6 +62,18 @@ public class GpuDelegate implements Delegate, Closeable {
       return this;
     }
 
+    /**
+     * Enables running quantized models with the delegate. Defaults to false.
+     *
+     * <p>WARNING: This is an experimental API and subject to change.
+     *
+     * @param quantizedModelsAllowed When {@code true}, the GPU may run quantized models.
+     */
+    public Options setQuantizedModelsAllowed(boolean quantizedModelsAllowed) {
+      this.quantizedModelsAllowed = quantizedModelsAllowed;
+      return this;
+    }
+
     /**
      * Sets the inference preference for precision/compilation/runtime tradeoffs.
      *
@@ -74,11 +86,16 @@ public class GpuDelegate implements Delegate, Closeable {
     }
 
     boolean precisionLossAllowed = true;
+    boolean quantizedModelsAllowed = false;
     int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
   }
 
   public GpuDelegate(Options options) {
-    delegateHandle = createDelegate(options.precisionLossAllowed, options.inferencePreference);
+    delegateHandle =
+        createDelegate(
+            options.precisionLossAllowed,
+            options.quantizedModelsAllowed,
+            options.inferencePreference);
   }
 
   public GpuDelegate() {
@@ -107,7 +124,8 @@ public class GpuDelegate implements Delegate, Closeable {
     System.loadLibrary(TFLITE_GPU_LIB);
   }
 
-  private static native long createDelegate(boolean precisionLossAllowed, int preference);
+  private static native long createDelegate(
+      boolean precisionLossAllowed, boolean quantizedModelsAllowed, int preference);
 
   private static native void deleteDelegate(long delegateHandle);
 }
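Since each Options setter returns this, the new flag chains with the existing configuration. A sketch, assuming the pre-existing setPrecisionLossAllowed() and setInferencePreference() setters implied by the surrounding context:

    GpuDelegate delegate =
        new GpuDelegate(
            new GpuDelegate.Options()
                .setPrecisionLossAllowed(true)
                .setQuantizedModelsAllowed(true) // new in this commit
                .setInferencePreference(
                    GpuDelegate.Options.INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER));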
gpu_delegate_jni.cc
@@ -23,7 +23,7 @@ extern "C" {
 
 JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
     JNIEnv* env, jclass clazz, jboolean precision_loss_allowed,
-    jint inference_preference) {
+    jboolean quantized_models_allowed, jint inference_preference) {
   TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
   if (precision_loss_allowed == JNI_TRUE) {
     options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
@@ -31,6 +31,10 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
         TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE;
     options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION;
   }
+  options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE;
+  if (quantized_models_allowed) {
+    options.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
+  }
   options.inference_preference = static_cast<int32_t>(inference_preference);
   return reinterpret_cast<jlong>(TfLiteGpuDelegateV2Create(&options));
 }
BUILD
@@ -353,6 +353,7 @@ filegroup(
 filegroup(
     name = "portable_gpu_tests",
     srcs = [
+        "src/test/java/org/tensorflow/lite/InterpreterTestHelper.java",
         "src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java",
     ],
     visibility = ["//visibility:public"],
Interpreter.java
@@ -491,6 +491,11 @@ public final class Interpreter implements AutoCloseable {
     wrapper.resetVariableTensors();
   }
 
+  int getExecutionPlanLength() {
+    checkNotClosed();
+    return wrapper.getExecutionPlanLength();
+  }
+
   /** Release resources associated with the {@code Interpreter}. */
   @Override
   public void close() {
NativeInterpreterWrapper.java
@@ -324,6 +324,11 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     return outputTensor;
   }
 
+  /** Gets the number of ops in the execution plan. */
+  int getExecutionPlanLength() {
+    return getExecutionPlanLength(interpreterHandle);
+  }
+
   private void applyDelegates(Interpreter.Options options) {
     // First apply the flex delegate if necessary. This ensures the graph is fully resolved before
     // applying other delegates.
@@ -419,6 +424,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
 
   private static native int getOutputCount(long interpreterHandle);
 
+  private static native int getExecutionPlanLength(long interpreterHandle);
+
   private static native String[] getInputNames(long interpreterHandle);
 
   private static native String[] getOutputNames(long interpreterHandle);
nativeinterpreterwrapper_jni.cc
@@ -241,6 +241,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
   return interpreter->outputs()[output_index];
 }
 
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
+    JNIEnv* env, jclass clazz, jlong handle) {
+  tflite_api_dispatcher::Interpreter* interpreter =
+      convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return 0;
+  return static_cast<jint>(interpreter->execution_plan().size());
+}
+
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
                                                                 jclass clazz,
InterpreterTestHelper.java (new file)
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+/** Utility for interacting with Interpreter in delegate tests. */
+public abstract class InterpreterTestHelper {
+
+  /**
+   * Returns the number of nodes in the execution plan that are invoked per inference.
+   *
+   * <p>WARNING: This is an experimental API and subject to change.
+   */
+  public static int executionPlanLength(Interpreter interpreter) {
+    return interpreter.getExecutionPlanLength();
+  }
+}
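InterpreterTestHelper exposes the package-private counter to tests, and its value is a quick proxy for whether delegation actually happened: the runtime replaces every node a delegate claims with a single delegate kernel, so a fully delegated graph reports a plan length of 1, while the same quantized model run on the CPU keeps its original node count (31 in the tests below). A minimal sketch, assuming interpreter is an open Interpreter:

    // One remaining node means the delegate consumed the entire graph.
    int planLength = InterpreterTestHelper.executionPlanLength(interpreter);
    boolean fullyDelegated = (planLength == 1);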
GpuDelegateTest.java
@@ -18,12 +18,17 @@ package org.tensorflow.lite.gpu;
 import static com.google.common.truth.Truth.assertThat;
 
 import java.nio.ByteBuffer;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.PriorityQueue;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 import org.tensorflow.lite.Interpreter;
+import org.tensorflow.lite.InterpreterTestHelper;
 import org.tensorflow.lite.TestUtils;
 
 /** Unit tests for {@link org.tensorflow.lite.gpu.GpuDelegate}. */
@@ -32,6 +37,9 @@ public final class GpuDelegateTest {
 
   private static final String MODEL_PATH = "tensorflow/lite/testdata/multi_add.bin";
   private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
+  private static final ByteBuffer MOBILENET_QUANTIZED_MODEL_BUFFER =
+      TestUtils.getTestFileAsBuffer(
+          "third_party/tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
 
   @Test
   public void testBasic() throws Exception {
@@ -41,7 +49,7 @@ public final class GpuDelegateTest {
   }
 
   @Test
-  public void testInterpreterWithGpu() throws Exception {
+  public void testInterpreterWithGpu_FloatModel() throws Exception {
     Interpreter.Options options = new Interpreter.Options();
     try (GpuDelegate delegate = new GpuDelegate();
         Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
@@ -60,4 +68,79 @@ public final class GpuDelegateTest {
       assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder();
     }
   }
+
+  @Test
+  public void testInterpreterWithGpu_QuantModelRunWithDelegate() throws Exception {
+    ByteBuffer img =
+        TestUtils.getTestImageAsByteBuffer(
+            "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
+
+    Interpreter.Options options = new Interpreter.Options();
+    try (GpuDelegate delegate =
+            new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true));
+        Interpreter interpreter =
+            new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
+      byte[][] output = new byte[1][1001];
+      interpreter.run(img, output);
+      // Should be only 1 node (Delegate) in the execution plan.
+      assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(1);
+      assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+      assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
+      // 653 == "military uniform"
+      assertThat(getTopKLabels(output, 3)).contains(653);
+    }
+  }
+
+  @Test
+  public void testInterpreterWithGpu_QuantModelRunOnCPU() throws Exception {
+    ByteBuffer img =
+        TestUtils.getTestImageAsByteBuffer(
+            "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
+
+    Interpreter.Options options = new Interpreter.Options();
+    try (GpuDelegate delegate = new GpuDelegate();
+        Interpreter interpreter =
+            new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
+      byte[][] output = new byte[1][1001];
+      interpreter.run(img, output);
+      // Original execution plan remains since default behavior doesn't allow quantized models.
+      assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(31);
+      assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+      assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
+      // 653 == "military uniform"
+      assertThat(getTopKLabels(output, 3)).contains(653);
+    }
+  }
+
+  private static ArrayList<Integer> getTopKLabels(byte[][] byteLabels, int k) {
+    float[][] labels = new float[1][1001];
+    for (int i = 0; i < byteLabels[0].length; ++i) {
+      labels[0][i] = (byteLabels[0][i] & 0xff) / 255.0f;
+    }
+    return getTopKLabels(labels, k);
+  }
+
+  private static ArrayList<Integer> getTopKLabels(float[][] labels, int k) {
+    PriorityQueue<Map.Entry<Integer, Float>> pq =
+        new PriorityQueue<>(
+            k,
+            new Comparator<Map.Entry<Integer, Float>>() {
+              @Override
+              public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) {
+                // Intentionally reversed to put high confidence at the head of the queue.
+                return o1.getValue().compareTo(o2.getValue()) * -1;
+              }
+            });
+
+    for (int i = 0; i < labels[0].length; ++i) {
+      pq.add(new AbstractMap.SimpleEntry<>(i, labels[0][i]));
+    }
+
+    final ArrayList<Integer> topKLabels = new ArrayList<>();
+    int topKLabelsSize = Math.min(pq.size(), k);
+    for (int i = 0; i < topKLabelsSize; ++i) {
+      topKLabels.add(pq.poll().getKey());
+    }
+    return topKLabels;
+  }
 }