diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
index 8d802ae044a..895f12f0233 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
+++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
@@ -62,6 +62,18 @@ public class GpuDelegate implements Delegate, Closeable {
return this;
}

+ /**
+ * Enables running quantized models with the delegate. Defaults to false.
+ *
+ * <p>WARNING: This is an experimental API and subject to change.
+ *
+ * @param quantizedModelsAllowed When {@code true}, the GPU may run quantized models.
+ */
+ public Options setQuantizedModelsAllowed(boolean quantizedModelsAllowed) {
+ this.quantizedModelsAllowed = quantizedModelsAllowed;
+ return this;
+ }
+
/**
* Sets the inference preference for precision/compilation/runtime tradeoffs.
*
@@ -74,11 +86,16 @@ public class GpuDelegate implements Delegate, Closeable {
}
boolean precisionLossAllowed = true;
+ boolean quantizedModelsAllowed = false;
int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
}
public GpuDelegate(Options options) {
- delegateHandle = createDelegate(options.precisionLossAllowed, options.inferencePreference);
+ delegateHandle =
+ createDelegate(
+ options.precisionLossAllowed,
+ options.quantizedModelsAllowed,
+ options.inferencePreference);
}
public GpuDelegate() {
@@ -107,7 +124,8 @@ public class GpuDelegate implements Delegate, Closeable {
System.loadLibrary(TFLITE_GPU_LIB);
}
- private static native long createDelegate(boolean precisionLossAllowed, int preference);
+ private static native long createDelegate(
+ boolean precisionLossAllowed, boolean quantizedModelsAllowed, int preference);
private static native void deleteDelegate(long delegateHandle);
}
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
index 089e2c2f816..900cc0e0d75 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
+++ b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
@@ -23,7 +23,7 @@ extern "C" {
JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
JNIEnv* env, jclass clazz, jboolean precision_loss_allowed,
- jint inference_preference) {
+ jboolean quantized_models_allowed, jint inference_preference) {
TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
if (precision_loss_allowed == JNI_TRUE) {
options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
@@ -31,6 +31,10 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE;
options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION;
}
+ options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE;
+ if (quantized_models_allowed) {
+ options.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
+ }
options.inference_preference = static_cast<int32_t>(inference_preference);
return reinterpret_cast<jlong>(TfLiteGpuDelegateV2Create(&options));
}
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTestHelper.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTestHelper.java
new file mode 100644
--- /dev/null
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTestHelper.java
+package org.tensorflow.lite;
+
+/** Utility for interacting with {@link Interpreter} in tests. */
+public abstract class InterpreterTestHelper {
+ /**
+ * Returns the number of nodes in the interpreter's execution plan.
+ *
+ * <p>WARNING: This is an experimental API and subject to change.
+ */
+ public static int executionPlanLength(Interpreter interpreter) {
+ return interpreter.getExecutionPlanLength();
+ }
+}
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
index 1fe4a531624..d92a7119aab 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
@@ -18,12 +18,17 @@ package org.tensorflow.lite.gpu;
import static com.google.common.truth.Truth.assertThat;
import java.nio.ByteBuffer;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
+import java.util.PriorityQueue;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.lite.Interpreter;
+import org.tensorflow.lite.InterpreterTestHelper;
import org.tensorflow.lite.TestUtils;
/** Unit tests for {@link org.tensorflow.lite.gpu.GpuDelegate}. */
@@ -32,6 +37,9 @@ public final class GpuDelegateTest {
private static final String MODEL_PATH = "tensorflow/lite/testdata/multi_add.bin";
private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
+ private static final ByteBuffer MOBILENET_QUANTIZED_MODEL_BUFFER =
+ TestUtils.getTestFileAsBuffer(
+ "third_party/tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
@Test
public void testBasic() throws Exception {
@@ -41,7 +49,7 @@ public final class GpuDelegateTest {
}
@Test
- public void testInterpreterWithGpu() throws Exception {
+ public void testInterpreterWithGpu_FloatModel() throws Exception {
Interpreter.Options options = new Interpreter.Options();
try (GpuDelegate delegate = new GpuDelegate();
Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
@@ -60,4 +68,79 @@ public final class GpuDelegateTest {
assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder();
}
}
+
+ @Test
+ public void testInterpreterWithGpu_QuantModelRunWithDelegate() throws Exception {
+ ByteBuffer img =
+ TestUtils.getTestImageAsByteBuffer(
+ "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
+
+ Interpreter.Options options = new Interpreter.Options();
+ try (GpuDelegate delegate =
+ new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true));
+ Interpreter interpreter =
+ new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
+ byte[][] output = new byte[1][1001];
+ interpreter.run(img, output);
+ // Should be only 1 node (Delegate) in the execution plan.
+ assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(1);
+ assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+ assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
+ // 653 == "military uniform"
+ assertThat(getTopKLabels(output, 3)).contains(653);
+ }
+ }
+
+ @Test
+ public void testInterpreterWithGpu_QuantModelRunOnCPU() throws Exception {
+ ByteBuffer img =
+ TestUtils.getTestImageAsByteBuffer(
+ "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
+
+ Interpreter.Options options = new Interpreter.Options();
+ try (GpuDelegate delegate = new GpuDelegate();
+ Interpreter interpreter =
+ new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
+ byte[][] output = new byte[1][1001];
+ interpreter.run(img, output);
+ // Original execution plan remains since default behavior doesn't allow quantized models.
+ assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(31);
+ assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+ assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
+ // 653 == "military uniform"
+ assertThat(getTopKLabels(output, 3)).contains(653);
+ }
+ }
+
+ private static ArrayList<Integer> getTopKLabels(byte[][] byteLabels, int k) {
+ float[][] labels = new float[1][1001];
+ for (int i = 0; i < byteLabels[0].length; ++i) {
+ labels[0][i] = (byteLabels[0][i] & 0xff) / 255.0f;
+ }
+ return getTopKLabels(labels, k);
+ }
+
+ private static ArrayList<Integer> getTopKLabels(float[][] labels, int k) {
+ PriorityQueue<Map.Entry<Integer, Float>> pq =
+ new PriorityQueue<>(
+ k,
+ new Comparator<Map.Entry<Integer, Float>>() {
+ @Override
+ public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) {
+ // Intentionally reversed to put high confidence at the head of the queue.
+ return o1.getValue().compareTo(o2.getValue()) * -1;
+ }
+ });
+
+ for (int i = 0; i < labels[0].length; ++i) {
+ pq.add(new AbstractMap.SimpleEntry<>(i, labels[0][i]));
+ }
+
+ final ArrayList<Integer> topKLabels = new ArrayList<>();
+ int topKLabelsSize = Math.min(pq.size(), k);
+ for (int i = 0; i < topKLabelsSize; ++i) {
+ topKLabels.add(pq.poll().getKey());
+ }
+ return topKLabels;
+ }
}
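For reference, a minimal sketch of how application code could use the new option; the wrapper class name, model buffer, and input image below are placeholders, and the 1001-element output assumes a quantized MobileNet-style classifier like the one exercised in the test above:

  import java.nio.ByteBuffer;
  import org.tensorflow.lite.Interpreter;
  import org.tensorflow.lite.gpu.GpuDelegate;

  /** Illustrative only: classify one image with a quantized model on the GPU delegate. */
  final class QuantizedGpuDelegateExample {
    static byte[][] classify(ByteBuffer quantizedModel, ByteBuffer inputImage) {
      // Opt in to quantized execution; without this the delegate rejects quantized graphs
      // and the interpreter falls back to the CPU kernels (see the second test above).
      GpuDelegate.Options gpuOptions = new GpuDelegate.Options().setQuantizedModelsAllowed(true);
      try (GpuDelegate delegate = new GpuDelegate(gpuOptions);
          Interpreter interpreter =
              new Interpreter(quantizedModel, new Interpreter.Options().addDelegate(delegate))) {
        byte[][] output = new byte[1][1001]; // uint8 scores, one per class
        interpreter.run(inputImage, output);
        return output;
      }
    }
  }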