Merge pull request #42404 from freedomtan:gpu_delegate_already_allows_quant

PiperOrigin-RevId: 327630219
Change-Id: I04153386bcf1a80af7356446ee279d29fb04d342
This commit is contained in:
TensorFlower Gardener 2020-08-20 08:31:04 -07:00
commit e7d27d8507
2 changed files with 47 additions and 47 deletions

View File

@ -326,58 +326,55 @@ public class Camera2BasicFragment extends Fragment
final int deviceIndex = deviceView.getCheckedItemPosition(); final int deviceIndex = deviceView.getCheckedItemPosition();
final int numThreads = np.getValue(); final int numThreads = np.getValue();
backgroundHandler.post(() -> { backgroundHandler.post(
if (modelIndex == currentModel && deviceIndex == currentDevice () -> {
if (modelIndex == currentModel
&& deviceIndex == currentDevice
&& numThreads == currentNumThreads) { && numThreads == currentNumThreads) {
return; return;
} }
currentModel = modelIndex; currentModel = modelIndex;
currentDevice = deviceIndex; currentDevice = deviceIndex;
currentNumThreads = numThreads; currentNumThreads = numThreads;
// Disable classifier while updating // Disable classifier while updating
if (classifier != null) { if (classifier != null) {
classifier.close(); classifier.close();
classifier = null; classifier = null;
} }
// Lookup names of parameters. // Lookup names of parameters.
String model = modelStrings.get(modelIndex); String model = modelStrings.get(modelIndex);
String device = deviceStrings.get(deviceIndex); String device = deviceStrings.get(deviceIndex);
Log.i(TAG, "Changing model to " + model + " device " + device); Log.i(TAG, "Changing model to " + model + " device " + device);
// Try to load model. // Try to load model.
try { try {
if (model.equals(mobilenetV1Quant)) { if (model.equals(mobilenetV1Quant)) {
classifier = new ImageClassifierQuantizedMobileNet(getActivity()); classifier = new ImageClassifierQuantizedMobileNet(getActivity());
} else if (model.equals(mobilenetV1Float)) { } else if (model.equals(mobilenetV1Float)) {
classifier = new ImageClassifierFloatMobileNet(getActivity()); classifier = new ImageClassifierFloatMobileNet(getActivity());
} else { } else {
showToast("Failed to load model"); showToast("Failed to load model");
} }
} catch (IOException e) { } catch (IOException e) {
Log.d(TAG, "Failed to load", e); Log.d(TAG, "Failed to load", e);
classifier = null; classifier = null;
} }
// Customize the interpreter to the type of device we want to use. // Customize the interpreter to the type of device we want to use.
if (classifier == null) { if (classifier == null) {
return; return;
} }
classifier.setNumThreads(numThreads); classifier.setNumThreads(numThreads);
if (device.equals(cpu)) { if (device.equals(cpu)) {
} else if (device.equals(gpu)) { } else if (device.equals(gpu)) {
if (model.equals(mobilenetV1Quant)) { classifier.useGpu();
showToast("gpu requires float model."); } else if (device.equals(nnApi)) {
classifier = null; classifier.useNNAPI();
} else { }
classifier.useGpu(); });
}
} else if (device.equals(nnApi)) {
classifier.useNNAPI();
}
});
} }
/** Connect the buttons to their event handler. */ /** Connect the buttons to their event handler. */

View File

@ -172,7 +172,10 @@ public abstract class ImageClassifier {
public void useGpu() { public void useGpu() {
if (gpuDelegate == null) { if (gpuDelegate == null) {
gpuDelegate = new GpuDelegate(); GpuDelegate.Options options = new GpuDelegate.Options();
options.setQuantizedModelsAllowed(true);
gpuDelegate = new GpuDelegate(options);
tfliteOptions.addDelegate(gpuDelegate); tfliteOptions.addDelegate(gpuDelegate);
recreateInterpreter(); recreateInterpreter();
} }