Merge pull request #16833 from Johnson145/lite_android_float_inception

Tensorflow Lite demo app for Android: add support for floating point models as Inception-v3
This commit is contained in:
Martin Wicke 2018-02-15 16:04:11 -08:00 committed by GitHub
commit 25b4086bb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1318 additions and 40 deletions

View File

@ -6,7 +6,7 @@ TensorFlow Lite uses many techniques for achieving low latency like optimizing t
![image](g3doc/TFLite-Architecture.jpg)
# Getting Started with an Android Demo App
This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo.
This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. A device running Android 5.0 ( API 21) or higher is required to run the demo.
There are 3 ways to get the demo app to your device
- Download the prebuilt binary or
@ -29,9 +29,16 @@ The simplest way to compile the demo app, and try out changes to the project cod
- Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings).
- Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project.
- Click through installing all the Gradle extensions it requests.
- Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip)
- unzip and copy mobilenet_quant_v1_224.tflite to the assets directory:
`tensorflow/contrib/lite/java/demo/app/src/main/assets/`
- Either
- Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip)
- unzip and copy mobilenet_quant_v1_224.tflite to the assets directory:
`tensorflow/contrib/lite/java/demo/app/src/main/assets/`
- Or download the floating point Inception-v3 model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip)
- unzip and copy inceptionv3_non_slim_2015.tflite to the assets directory
- change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java) from
`classifier = new ImageClassifierQuantizedMobileNet(getActivity());`
to
`classifier = new ImageClassifierFloatInception(getActivity());`
- Build and run the demo app
## Building TensorFlow Lite and the demo app from source
@ -84,7 +91,7 @@ Currently, we only support building the Android demo app within a Python 2
environment (due to a Bazel bug).
### More about the demo
The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app.
The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used (229 * 229 for Inception-v3). The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch. 224 * 224 (299 * 299) is the width and height of the image. 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. Both models have 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The model file must be downloaded and bundled within the assets directory of the app.
# iOS Demo App

File diff suppressed because it is too large Load Diff

View File

@ -296,7 +296,8 @@ public class Camera2BasicFragment extends Fragment
public void onActivityCreated(Bundle savedInstanceState) {
super.onActivityCreated(savedInstanceState);
try {
classifier = new ImageClassifier(getActivity());
// create either a new ImageClassifierQuantizedMobileNet or an ImageClassifierFloatInception
classifier = new ImageClassifierQuantizedMobileNet(getActivity());
} catch (IOException e) {
Log.e(TAG, "Failed to initialize an image classifier.");
}
@ -659,7 +660,7 @@ public class Camera2BasicFragment extends Fragment
return;
}
Bitmap bitmap =
textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y);
textureView.getBitmap(classifier.getImageSizeX(), classifier.getImageSizeY());
String textToShow = classifier.classifyFrame(bitmap);
bitmap.recycle();
showToast(textToShow);

View File

@ -20,6 +20,9 @@ import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.os.SystemClock;
import android.util.Log;
import org.tensorflow.lite.Interpreter;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
@ -34,20 +37,15 @@ import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;
/** Classifies images with Tensorflow Lite. */
public class ImageClassifier {
/**
* Classifies images with Tensorflow Lite.
*/
public abstract class ImageClassifier {
/** Tag for the {@link Log}. */
private static final String TAG = "TfLiteCameraDemo";
/** Name of the model file stored in Assets. */
private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite";
/** Name of the label file stored in Assets. */
private static final String LABEL_PATH = "labels.txt";
/** Number of results to show in the UI. */
private static final int RESULTS_TO_SHOW = 3;
@ -56,23 +54,18 @@ public class ImageClassifier {
private static final int DIM_PIXEL_SIZE = 3;
static final int DIM_IMG_SIZE_X = 224;
static final int DIM_IMG_SIZE_Y = 224;
/* Preallocated buffers for storing image data in. */
private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
private int[] intValues = new int[getImageSizeX() * getImageSizeY()];
/** An instance of the driver class to run model inference with Tensorflow Lite. */
private Interpreter tflite;
protected Interpreter tflite;
/** Labels corresponding to the output of the vision model. */
private List<String> labelList;
/** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */
private ByteBuffer imgData = null;
protected ByteBuffer imgData = null;
/** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */
private byte[][] labelProbArray = null;
/** multi-stage low pass filter * */
private float[][] filterLabelProbArray = null;
@ -95,10 +88,10 @@ public class ImageClassifier {
labelList = loadLabelList(activity);
imgData =
ByteBuffer.allocateDirect(
DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
DIM_BATCH_SIZE * getImageSizeX() * getImageSizeY() * DIM_PIXEL_SIZE *
getNumBytesPerChannel());
imgData.order(ByteOrder.nativeOrder());
labelProbArray = new byte[1][labelList.size()];
filterLabelProbArray = new float[FILTER_STAGES][labelList.size()];
filterLabelProbArray = new float[FILTER_STAGES][getNumLabels()];
Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
}
@ -111,7 +104,7 @@ public class ImageClassifier {
convertBitmapToByteBuffer(bitmap);
// Here's where the magic happens!!!
long startTime = SystemClock.uptimeMillis();
tflite.run(imgData, labelProbArray);
runInference();
long endTime = SystemClock.uptimeMillis();
Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
@ -125,12 +118,12 @@ public class ImageClassifier {
}
void applyFilter() {
int numLabels = labelList.size();
int numLabels = getNumLabels();
// Low pass filter `labelProbArray` into the first stage of the filter.
for (int j = 0; j < numLabels; ++j) {
filterLabelProbArray[0][j] +=
FILTER_FACTOR * (labelProbArray[0][j] - filterLabelProbArray[0][j]);
FILTER_FACTOR * (getProbability(j) - filterLabelProbArray[0][j]);
}
// Low pass filter each stage into the next.
for (int i = 1; i < FILTER_STAGES; ++i) {
@ -142,7 +135,7 @@ public class ImageClassifier {
// Copy the last stage filter output back to `labelProbArray`.
for (int j = 0; j < numLabels; ++j) {
labelProbArray[0][j] = (byte)filterLabelProbArray[FILTER_STAGES - 1][j];
setProbability(j, filterLabelProbArray[FILTER_STAGES - 1][j]);
}
}
@ -156,7 +149,7 @@ public class ImageClassifier {
private List<String> loadLabelList(Activity activity) throws IOException {
List<String> labelList = new ArrayList<String>();
BufferedReader reader =
new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
new BufferedReader(new InputStreamReader(activity.getAssets().open(getLabelPath())));
String line;
while ((line = reader.readLine()) != null) {
labelList.add(line);
@ -167,7 +160,7 @@ public class ImageClassifier {
/** Memory-map the model file in Assets. */
private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath());
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
@ -185,12 +178,10 @@ public class ImageClassifier {
// Convert the image to floating point.
int pixel = 0;
long startTime = SystemClock.uptimeMillis();
for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
for (int i = 0; i < getImageSizeX(); ++i) {
for (int j = 0; j < getImageSizeY(); ++j) {
final int val = intValues[pixel++];
imgData.put((byte) ((val >> 16) & 0xFF));
imgData.put((byte) ((val >> 8) & 0xFF));
imgData.put((byte) (val & 0xFF));
addPixelValue(val);
}
}
long endTime = SystemClock.uptimeMillis();
@ -199,9 +190,9 @@ public class ImageClassifier {
/** Prints top-K labels, to be shown in UI as the results. */
private String printTopKLabels() {
for (int i = 0; i < labelList.size(); ++i) {
for (int i = 0; i < getNumLabels(); ++i) {
sortedLabels.add(
new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f));
new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i)));
if (sortedLabels.size() > RESULTS_TO_SHOW) {
sortedLabels.poll();
}
@ -214,4 +205,80 @@ public class ImageClassifier {
}
return textToShow;
}
/**
* Get the name of the model file stored in Assets.
* @return
*/
protected abstract String getModelPath();
/**
* Get the name of the label file stored in Assets.
* @return
*/
protected abstract String getLabelPath();
/**
* Get the image size along the x axis.
* @return
*/
protected abstract int getImageSizeX();
/**
* Get the image size along the y axis.
* @return
*/
protected abstract int getImageSizeY();
/**
* Get the number of bytes that is used to store a single color channel value.
* @return
*/
protected abstract int getNumBytesPerChannel();
/**
* Add pixelValue to byteBuffer.
* @param pixelValue
*/
protected abstract void addPixelValue(int pixelValue);
/**
* Read the probability value for the specified label
* This is either the original value as it was read from the net's output or the updated value
* after the filter was applied.
* @param labelIndex
* @return
*/
protected abstract float getProbability(int labelIndex);
/**
* Set the probability value for the specified label.
* @param labelIndex
* @param value
*/
protected abstract void setProbability(int labelIndex, Number value);
/**
* Get the normalized probability value for the specified label.
* This is the final value as it will be shown to the user.
* @return
*/
protected abstract float getNormalizedProbability(int labelIndex);
/**
* Run inference using the prepared input in {@link #imgData}.
* Afterwards, the result will be provided by getProbability().
*
* This additional method is necessary, because we don't have a common base for different
* primitive data types.
*/
protected abstract void runInference();
/**
* Get the total number of labels.
* @return
*/
protected int getNumLabels() {
return labelList.size();
}
}

View File

@ -0,0 +1,105 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.example.android.tflitecamerademo;
import android.app.Activity;
import java.io.IOException;
/**
* This classifier works with the Inception-v3 slim model.
* It applies floating point inference rather than using a quantized model.
*/
public class ImageClassifierFloatInception extends ImageClassifier {
/**
* The inception net requires additional normalization of the used input.
*/
private static final int IMAGE_MEAN = 128;
private static final float IMAGE_STD = 128.0f;
/**
* An array to hold inference results, to be feed into Tensorflow Lite as outputs.
* This isn't part of the super class, because we need a primitive array here.
*/
private float[][] labelProbArray = null;
/**
* Initializes an {@code ImageClassifier}.
*
* @param activity
*/
ImageClassifierFloatInception(Activity activity) throws IOException {
super(activity);
labelProbArray = new float[1][getNumLabels()];
}
@Override
protected String getModelPath() {
// you can download this file from
// https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip
return "inceptionv3_slim_2016.tflite";
}
@Override
protected String getLabelPath() {
return "labels_imagenet_slim.txt";
}
@Override
protected int getImageSizeX() {
return 299;
}
@Override
protected int getImageSizeY() {
return 299;
}
@Override
protected int getNumBytesPerChannel() {
// a 32bit float value requires 4 bytes
return 4;
}
@Override
protected void addPixelValue(int pixelValue) {
imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
}
@Override
protected float getProbability(int labelIndex) {
return labelProbArray[0][labelIndex];
}
@Override
protected void setProbability(int labelIndex, Number value) {
labelProbArray[0][labelIndex] = value.floatValue();
}
@Override
protected float getNormalizedProbability(int labelIndex) {
// TODO the following value isn't in [0,1] yet, but may be greater. Why?
return getProbability(labelIndex);
}
@Override
protected void runInference() {
tflite.run(imgData, labelProbArray);
}
}

View File

@ -0,0 +1,97 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.example.android.tflitecamerademo;
import android.app.Activity;
import java.io.IOException;
/**
* This classifier works with the quantized MobileNet model.
*/
public class ImageClassifierQuantizedMobileNet extends ImageClassifier {
/**
* An array to hold inference results, to be feed into Tensorflow Lite as outputs.
* This isn't part of the super class, because we need a primitive array here.
*/
private byte[][] labelProbArray = null;
/**
* Initializes an {@code ImageClassifier}.
*
* @param activity
*/
ImageClassifierQuantizedMobileNet(Activity activity) throws IOException {
super(activity);
labelProbArray = new byte[1][getNumLabels()];
}
@Override
protected String getModelPath() {
// you can download this file from
// https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
return "mobilenet_quant_v1_224.tflite";
}
@Override
protected String getLabelPath() {
return "labels_mobilenet_quant_v1_224.txt";
}
@Override
protected int getImageSizeX() {
return 224;
}
@Override
protected int getImageSizeY() {
return 224;
}
@Override
protected int getNumBytesPerChannel() {
// the quantized model uses a single byte only
return 1;
}
@Override
protected void addPixelValue(int pixelValue) {
imgData.put((byte) ((pixelValue >> 16) & 0xFF));
imgData.put((byte) ((pixelValue >> 8) & 0xFF));
imgData.put((byte) (pixelValue & 0xFF));
}
@Override
protected float getProbability(int labelIndex) {
return labelProbArray[0][labelIndex];
}
@Override
protected void setProbability(int labelIndex, Number value) {
labelProbArray[0][labelIndex] = value.byteValue();
}
@Override
protected float getNormalizedProbability(int labelIndex) {
return (labelProbArray[0][labelIndex] & 0xff) / 255.0f;
}
@Override
protected void runInference() {
tflite.run(imgData, labelProbArray);
}
}