Android: use return code from initializeTensorFlow to catch any remaining initialization issues before inference is attempted (e.g. decodeJpeg nodes not being stripped from graph).

Change: 143060930
This commit is contained in:
Andrew Harp 2016-12-27 15:21:20 -08:00 committed by TensorFlower Gardener
parent 88d776bfd4
commit c9722179a8
2 changed files with 36 additions and 19 deletions
tensorflow/examples/android/src/org/tensorflow/demo

View File

@ -30,7 +30,6 @@ import android.os.Trace;
import android.util.Size;
import android.util.TypedValue;
import android.view.Display;
import java.io.IOException;
import java.util.List;
import java.util.Vector;
import org.tensorflow.demo.OverlayView.DrawCallback;
@ -108,12 +107,25 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
borderedText = new BorderedText(textSizePx);
classifier = new TensorFlowImageClassifier();
try {
classifier.initializeTensorFlow(
getAssets(), MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD,
INPUT_NAME, OUTPUT_NAME);
} catch (final IOException e) {
LOGGER.e(e, "Exception!");
final int initStatus =
classifier.initializeTensorFlow(
getAssets(),
MODEL_FILE,
LABEL_FILE,
NUM_CLASSES,
INPUT_SIZE,
IMAGE_MEAN,
IMAGE_STD,
INPUT_NAME,
OUTPUT_NAME);
if (initStatus != 0) {
LOGGER.e("TF init status != 0: %d", initStatus);
throw new RuntimeException();
}
} catch (final Exception e) {
throw new RuntimeException("Error initializing TensorFlow!", e);
}
resultsView = (ResultsView) findViewById(R.id.results);

View File

@ -33,7 +33,6 @@ import android.os.Trace;
import android.util.Size;
import android.util.TypedValue;
import android.view.Display;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;
@ -108,19 +107,25 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
tracker = new MultiBoxTracker(getResources().getDisplayMetrics());
detector = new TensorFlowMultiBoxDetector();
try {
detector.initializeTensorFlow(
getAssets(),
MODEL_FILE,
LOCATION_FILE,
NUM_LOCATIONS,
INPUT_SIZE,
IMAGE_MEAN,
IMAGE_STD,
INPUT_NAME,
OUTPUT_NAMES);
} catch (final IOException e) {
LOGGER.e(e, "Exception!");
final int initStatus =
detector.initializeTensorFlow(
getAssets(),
MODEL_FILE,
LOCATION_FILE,
NUM_LOCATIONS,
INPUT_SIZE,
IMAGE_MEAN,
IMAGE_STD,
INPUT_NAME,
OUTPUT_NAMES);
if (initStatus != 0) {
LOGGER.e("TF init status != 0: %d", initStatus);
throw new RuntimeException();
}
} catch (final Exception e) {
throw new RuntimeException("Error initializing TensorFlow!", e);
}
previewWidth = size.getWidth();