diff --git a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
index 8d802ae044a..895f12f0233 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
+++ b/tensorflow/lite/delegates/gpu/java/src/main/java/org/tensorflow/lite/gpu/GpuDelegate.java
@@ -62,6 +62,18 @@ public class GpuDelegate implements Delegate, Closeable {
       return this;
     }
 
+    /**
+     * Enables running quantized models with the delegate. Defaults to false.
+     *
+     * <p>WARNING: This is an experimental API and subject to change.
+     *
+     * @param quantizedModelsAllowed When {@code true}, the GPU may run quantized models.
+     */
+    public Options setQuantizedModelsAllowed(boolean quantizedModelsAllowed) {
+      this.quantizedModelsAllowed = quantizedModelsAllowed;
+      return this;
+    }
+
     /**
      * Sets the inference preference for precision/compilation/runtime tradeoffs.
      *
@@ -74,11 +86,16 @@ public class GpuDelegate implements Delegate, Closeable {
     }
 
     boolean precisionLossAllowed = true;
+    boolean quantizedModelsAllowed = false;
     int inferencePreference = INFERENCE_PREFERENCE_FAST_SINGLE_ANSWER;
   }
 
   public GpuDelegate(Options options) {
-    delegateHandle = createDelegate(options.precisionLossAllowed, options.inferencePreference);
+    delegateHandle =
+        createDelegate(
+            options.precisionLossAllowed,
+            options.quantizedModelsAllowed,
+            options.inferencePreference);
   }
 
   public GpuDelegate() {
@@ -107,7 +124,8 @@ public class GpuDelegate implements Delegate, Closeable {
     System.loadLibrary(TFLITE_GPU_LIB);
   }
 
-  private static native long createDelegate(boolean precisionLossAllowed, int preference);
+  private static native long createDelegate(
+      boolean precisionLossAllowed, boolean quantizedModelsAllowed, int preference);
 
   private static native void deleteDelegate(long delegateHandle);
 }
diff --git a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
index 089e2c2f816..900cc0e0d75 100644
--- a/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
+++ b/tensorflow/lite/delegates/gpu/java/src/main/native/gpu_delegate_jni.cc
@@ -23,7 +23,7 @@ extern "C" {
 
 JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
     JNIEnv* env, jclass clazz, jboolean precision_loss_allowed,
-    jint inference_preference) {
+    jboolean quantized_models_allowed, jint inference_preference) {
   TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
   if (precision_loss_allowed == JNI_TRUE) {
     options.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
@@ -31,6 +31,10 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_gpu_GpuDelegate_createDelegate(
         TFLITE_GPU_INFERENCE_PRIORITY_MIN_MEMORY_USAGE;
     options.inference_priority3 = TFLITE_GPU_INFERENCE_PRIORITY_MAX_PRECISION;
   }
+  options.experimental_flags = TFLITE_GPU_EXPERIMENTAL_FLAGS_NONE;
+  if (quantized_models_allowed) {
+    options.experimental_flags |= TFLITE_GPU_EXPERIMENTAL_FLAGS_ENABLE_QUANT;
+  }
   options.inference_preference = static_cast<int32_t>(inference_preference);
   return reinterpret_cast<jlong>(TfLiteGpuDelegateV2Create(&options));
 }
diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD
index 46cd1be25cb..5eb5e8ab023 100644
--- a/tensorflow/lite/java/BUILD
+++ b/tensorflow/lite/java/BUILD
@@ -353,6 +353,7 @@ filegroup(
 filegroup(
     name = "portable_gpu_tests",
     srcs = [
+        "src/test/java/org/tensorflow/lite/InterpreterTestHelper.java",
         "src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java",
     ],
     visibility = ["//visibility:public"],
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index d191b550d8f..5625ef98bb6 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -491,6 +491,11 @@ public final class Interpreter implements AutoCloseable {
     wrapper.resetVariableTensors();
   }
 
+  int getExecutionPlanLength() {
+    checkNotClosed();
+    return wrapper.getExecutionPlanLength();
+  }
+
   /** Release resources associated with the {@code Interpreter}. */
   @Override
   public void close() {
diff --git a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index a22d7241587..8eb3c66f3b5 100644
--- a/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -324,6 +324,11 @@ final class NativeInterpreterWrapper implements AutoCloseable {
     return outputTensor;
   }
 
+  /** Gets the number of ops in the execution plan. */
+  int getExecutionPlanLength() {
+    return getExecutionPlanLength(interpreterHandle);
+  }
+
   private void applyDelegates(Interpreter.Options options) {
     // First apply the flex delegate if necessary. This ensures the graph is fully resolved before
     // applying other delegates.
@@ -419,6 +424,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
 
   private static native int getOutputCount(long interpreterHandle);
 
+  private static native int getExecutionPlanLength(long interpreterHandle);
+
   private static native String[] getInputNames(long interpreterHandle);
 
   private static native String[] getOutputNames(long interpreterHandle);
diff --git a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 971aa5efd7a..690b58ac1f4 100644
--- a/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -241,6 +241,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
   return interpreter->outputs()[output_index];
 }
 
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getExecutionPlanLength(
+    JNIEnv* env, jclass clazz, jlong handle) {
+  tflite_api_dispatcher::Interpreter* interpreter =
+      convertLongToInterpreter(env, handle);
+  if (interpreter == nullptr) return 0;
+  return static_cast<jint>(interpreter->execution_plan().size());
+}
+
 JNIEXPORT jint JNICALL
 Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
                                                                 jclass clazz,
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTestHelper.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTestHelper.java
new file mode 100644
index 00000000000..34eb47e4dbe
--- /dev/null
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/InterpreterTestHelper.java
@@ -0,0 +1,29 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+/** Utility for interacting with Interpreter in delegate tests. */
+public abstract class InterpreterTestHelper {
+
+  /**
+   * Returns the number of nodes in the execution plan that are invoked per inference.
+   *
+   * <p>WARNING: This is an experimental API and subject to change.
+   */
+  public static int executionPlanLength(Interpreter interpreter) {
+    return interpreter.getExecutionPlanLength();
+  }
+}
diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
index 1fe4a531624..d92a7119aab 100644
--- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
+++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/GpuDelegateTest.java
@@ -18,12 +18,17 @@ package org.tensorflow.lite.gpu;
 import static com.google.common.truth.Truth.assertThat;
 
 import java.nio.ByteBuffer;
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.PriorityQueue;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 import org.tensorflow.lite.Interpreter;
+import org.tensorflow.lite.InterpreterTestHelper;
 import org.tensorflow.lite.TestUtils;
 
 /** Unit tests for {@link org.tensorflow.lite.gpu.GpuDelegate}. */
@@ -32,6 +37,9 @@ public final class GpuDelegateTest {
 
   private static final String MODEL_PATH = "tensorflow/lite/testdata/multi_add.bin";
   private static final ByteBuffer MODEL_BUFFER = TestUtils.getTestFileAsBuffer(MODEL_PATH);
+  private static final ByteBuffer MOBILENET_QUANTIZED_MODEL_BUFFER =
+      TestUtils.getTestFileAsBuffer(
+          "third_party/tensorflow/lite/java/demo/app/src/main/assets/mobilenet_v1_1.0_224_quant.tflite");
 
   @Test
   public void testBasic() throws Exception {
@@ -41,7 +49,7 @@ public final class GpuDelegateTest {
   }
 
   @Test
-  public void testInterpreterWithGpu() throws Exception {
+  public void testInterpreterWithGpu_FloatModel() throws Exception {
     Interpreter.Options options = new Interpreter.Options();
     try (GpuDelegate delegate = new GpuDelegate();
         Interpreter interpreter = new Interpreter(MODEL_BUFFER, options.addDelegate(delegate))) {
@@ -60,4 +68,79 @@ public final class GpuDelegateTest {
       assertThat(parsedOutput1).usingTolerance(0.1f).containsExactly(expected1).inOrder();
     }
   }
+
+  @Test
+  public void testInterpreterWithGpu_QuantModelRunWithDelegate() throws Exception {
+    ByteBuffer img =
+        TestUtils.getTestImageAsByteBuffer(
+            "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
+
+    Interpreter.Options options = new Interpreter.Options();
+    try (GpuDelegate delegate =
+            new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true));
+        Interpreter interpreter =
+            new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
+      byte[][] output = new byte[1][1001];
+      interpreter.run(img, output);
+      // Should be only 1 node (the delegate) in the execution plan.
+      assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(1);
+      assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+      assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
+      // 653 == "military uniform"
+      assertThat(getTopKLabels(output, 3)).contains(653);
+    }
+  }
+
+  @Test
+  public void testInterpreterWithGpu_QuantModelRunOnCPU() throws Exception {
+    ByteBuffer img =
+        TestUtils.getTestImageAsByteBuffer(
+            "tensorflow/lite/java/src/testdata/grace_hopper_224.jpg");
+
+    Interpreter.Options options = new Interpreter.Options();
+    try (GpuDelegate delegate = new GpuDelegate();
+        Interpreter interpreter =
+            new Interpreter(MOBILENET_QUANTIZED_MODEL_BUFFER, options.addDelegate(delegate))) {
+      byte[][] output = new byte[1][1001];
+      interpreter.run(img, output);
+      // Original execution plan remains since default behavior doesn't allow quantized models.
+      assertThat(InterpreterTestHelper.executionPlanLength(interpreter)).isEqualTo(31);
+      assertThat(interpreter.getInputTensor(0).shape()).isEqualTo(new int[] {1, 224, 224, 3});
+      assertThat(interpreter.getOutputTensor(0).shape()).isEqualTo(new int[] {1, 1001});
+      // 653 == "military uniform"
+      assertThat(getTopKLabels(output, 3)).contains(653);
+    }
+  }
+
+  private static ArrayList<Integer> getTopKLabels(byte[][] byteLabels, int k) {
+    float[][] labels = new float[1][1001];
+    for (int i = 0; i < byteLabels[0].length; ++i) {
+      labels[0][i] = (byteLabels[0][i] & 0xff) / 255.0f;
+    }
+    return getTopKLabels(labels, k);
+  }
+
+  private static ArrayList<Integer> getTopKLabels(float[][] labels, int k) {
+    PriorityQueue<Map.Entry<Integer, Float>> pq =
+        new PriorityQueue<>(
+            k,
+            new Comparator<Map.Entry<Integer, Float>>() {
+              @Override
+              public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) {
+                // Intentionally reversed to put high confidence at the head of the queue.
+                return o1.getValue().compareTo(o2.getValue()) * -1;
+              }
+            });
+
+    for (int i = 0; i < labels[0].length; ++i) {
+      pq.add(new AbstractMap.SimpleEntry<>(i, labels[0][i]));
+    }
+
+    final ArrayList<Integer> topKLabels = new ArrayList<>();
+    int topKLabelsSize = Math.min(pq.size(), k);
+    for (int i = 0; i < topKLabelsSize; ++i) {
+      topKLabels.add(pq.poll().getKey());
+    }
+    return topKLabels;
+  }
 }
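
For context, a minimal usage sketch of the new option (illustrative only, not part of the patch; the QuantizedGpuExample class and its caller-supplied buffers are assumptions, mirroring the test above):

import java.nio.ByteBuffer;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.gpu.GpuDelegate;

/** Hypothetical caller: opt in to quantized inference on the GPU delegate. */
final class QuantizedGpuExample {
  static void runOnce(ByteBuffer quantizedModel, ByteBuffer input, byte[][] output) {
    // Without setQuantizedModelsAllowed(true), the delegate skips quantized graphs
    // and the original CPU execution plan is kept (see the RunOnCPU test above).
    try (GpuDelegate delegate =
            new GpuDelegate(new GpuDelegate.Options().setQuantizedModelsAllowed(true));
        Interpreter interpreter =
            new Interpreter(quantizedModel, new Interpreter.Options().addDelegate(delegate))) {
      interpreter.run(input, output);
    }
  }
}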