Android: show inference stats on debug screen in demo (accessed with volume keys).

Change: 143149923
This commit is contained in:
Andrew Harp 2016-12-28 19:38:03 -08:00 committed by TensorFlower Gardener
parent a081f4b06f
commit a1a3b0c6c3
10 changed files with 147 additions and 52 deletions

View File

@@ -76,10 +76,18 @@ public class TensorFlowInferenceInterface {
*/
public native int runInference(String[] outputNames);
/**
* Whether to collect and log stats to logcat during inference via StepStats and StatSummarizer.
* This should only be enabled when needed, as it will add overhead.
*/
public native void enableStatLogging(boolean enabled);
/** Returns the last stat summary string if logging is enabled. */
public native String getStatString();
/**
* Cleans up the native variables associated with this Object. initializeTensorFlow() can then
* be called again to initialize a new session.
*
*/
public native void close();

View File

@@ -53,6 +53,9 @@ struct SessionVariables {
int num_runs = 0;
int64 timing_total_us = 0;
bool log_stats = false;
StatSummarizer* summarizer = nullptr;
InputMap input_tensors;
std::vector<std::string> output_tensor_names;
std::vector<tensorflow::Tensor> output_tensors;
@@ -129,6 +132,10 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorFlow)(
LOG(INFO) << "GraphDef loaded from " << model_str << " with "
<< tensorflow_graph.node_size() << " nodes.";
// Whether or not stat logging is currently enabled, the StatSummarizer must
// be initialized here with the GraphDef while it is available.
vars->summarizer = new StatSummarizer(tensorflow_graph);
LOG(INFO) << "Creating TensorFlow graph from GraphDef.";
tensorflow::Status s = session->Create(tensorflow_graph);
@@ -193,8 +200,28 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(
}
vars->output_tensors.clear();
s = vars->session->Run(input_tensors, vars->output_tensor_names, {},
&(vars->output_tensors));
if (vars->log_stats) {
RunOptions run_options;
run_options.set_trace_level(RunOptions::FULL_TRACE);
RunMetadata run_metadata;
s = vars->session->Run(run_options, input_tensors,
vars->output_tensor_names, {},
&(vars->output_tensors), &run_metadata);
assert(run_metadata.has_step_stats());
const StepStats& step_stats = run_metadata.step_stats();
vars->summarizer->ProcessStepStats(step_stats);
// Print the full output string, not just the abbreviated one returned by
// getStatString().
vars->summarizer->PrintStepStats();
} else {
s = vars->session->Run(input_tensors, vars->output_tensor_names, {},
&(vars->output_tensors));
}
end_time = CurrentWallTimeUs();
const int64 elapsed_time_inf = end_time - start_time;
vars->timing_total_us += elapsed_time_inf;
@@ -208,6 +235,24 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(
return s.code();
}
JNIEXPORT void JNICALL TENSORFLOW_METHOD(enableStatLogging)(
JNIEnv* env, jobject thiz, jboolean enableStatLogging) {
SessionVariables* vars = GetSessionVars(env, thiz);
vars->log_stats = enableStatLogging;
}
JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(getStatString)(JNIEnv* env,
jobject thiz) {
// Return an abbreviated stat string suitable for displaying on screen.
SessionVariables* vars = GetSessionVars(env, thiz);
std::stringstream ss;
ss << vars->summarizer->GetStatsByMetric("Top 10 CPU",
StatSummarizer::BY_TIME, 10);
ss << vars->summarizer->GetStatsByNodeType();
ss << vars->summarizer->ShortSummary();
return env->NewStringUTF(ss.str().c_str());
}
JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
SessionVariables* vars = GetSessionVars(env, thiz);
@@ -216,6 +261,8 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
LOG(ERROR) << "Error closing session: " << s;
}
delete vars->summarizer;
mutex_lock l(mutex_);
std::map<int64, SessionVariables*>& sessions = *GetSessionsSingleton();
sessions.erase(vars->id);

View File

@@ -48,6 +48,12 @@ JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorFlow)(
JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(
JNIEnv* env, jobject thiz, jobjectArray output_name_strings);
JNIEXPORT void JNICALL TENSORFLOW_METHOD(enableStatLogging)(
JNIEnv* env, jobject thiz, jboolean enableStatLogging);
JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(getStatString)(JNIEnv* env,
jobject thiz);
JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz);
FILL_NODE_SIGNATURE(Float, float);

View File

@@ -113,6 +113,15 @@ class Stat {
// See tensorflow/examples/android/jni/tensorflow_jni.cc for an example usage.
class StatSummarizer {
public:
enum SortingMetric {
BY_NAME,
BY_DEFINITION_ORDER,
BY_RUN_ORDER,
BY_TIME,
BY_MEMORY,
BY_TYPE,
};
explicit StatSummarizer(const tensorflow::GraphDef& tensorflow_graph);
// Adds another run's StepStats output to the aggregate counts.
@@ -122,6 +131,8 @@ class StatSummarizer {
// format which can be pasted into a spreadsheet for further analysis.
std::string GetOutputString() const;
std::string ShortSummary() const;
// Prints the string returned by GetOutputString().
void PrintStepStats() const;
@@ -130,6 +141,10 @@ class StatSummarizer {
std::string GetStatsByNodeType() const;
std::string GetStatsByMetric(const string& title,
SortingMetric sorting_metric,
int num_stats) const;
void Reset() {
run_total_us_.Reset();
memory_.Reset();
@@ -153,31 +168,16 @@ class StatSummarizer {
std::vector<TensorDescription> outputs;
};
enum SortingMetric {
BY_NAME,
BY_DEFINITION_ORDER,
BY_RUN_ORDER,
BY_TIME,
BY_MEMORY,
BY_TYPE,
};
void Validate(const Detail* detail, const NodeExecStats& ns) const;
void OrderNodesByMetric(SortingMetric sorting_metric,
std::vector<const Detail*>* details) const;
std::string GetStatsByMetric(const string& title,
SortingMetric sorting_metric,
int num_stats) const;
std::string HeaderString(const string& title) const;
std::string ColumnString(const Detail& detail,
const int64 cumulative_stat_on_node,
const Stat<int64>& stat) const;
std::string ShortSummary() const;
Stat<int64> run_total_us_;
Stat<int64> memory_;

View File

@@ -197,11 +197,14 @@ public abstract class CameraActivity extends Activity implements OnImageAvailabl
}
}
public void onSetDebug(boolean debug) {}
@Override
public boolean onKeyDown(final int keyCode, final KeyEvent event) {
if (keyCode == KeyEvent.KEYCODE_VOLUME_DOWN || keyCode == KeyEvent.KEYCODE_VOLUME_UP) {
debug = !debug;
requestRender();
onSetDebug(debug);
return true;
}
return super.onKeyDown(keyCode, event);

View File

@@ -21,6 +21,7 @@ import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Typeface;
import android.media.Image;
import android.media.Image.Plane;
import android.media.ImageReader;
@@ -97,7 +98,7 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
return INPUT_SIZE;
}
private static final float TEXT_SIZE_DIP = 18;
private static final float TEXT_SIZE_DIP = 10;
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
@@ -105,6 +106,7 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP,
getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
classifier = new TensorFlowImageClassifier();
@@ -237,6 +239,11 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
Trace.endSection();
}
@Override
public void onSetDebug(boolean debug) {
classifier.enableStatLogging(debug);
}
private void renderDebug(final Canvas canvas) {
if (!isDebug()) {
return;
@@ -252,18 +259,21 @@ public class ClassifierActivity extends CameraActivity implements OnImageAvailab
canvas.drawBitmap(copy, matrix, new Paint());
final Vector<String> lines = new Vector<String>();
if (classifier != null) {
String statString = classifier.getStatString();
String[] statLines = statString.split("\n");
for (String line : statLines) {
lines.add(line);
}
}
lines.add("Frame: " + previewWidth + "x" + previewHeight);
lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
lines.add("Rotation: " + sensorOrientation);
lines.add("Inference time: " + lastProcessingTimeMs + "ms");
int lineNum = 0;
for (final String line : lines) {
borderedText.drawText(canvas, 10,
canvas.getHeight() - 10 - borderedText.getTextSize() * lineNum, line);
++lineNum;
}
borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
}
}
}

View File

@@ -24,6 +24,7 @@ import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.media.Image;
import android.media.Image.Plane;
import android.media.ImageReader;
@@ -67,7 +68,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
private static final boolean MAINTAIN_ASPECT = false;
private static final float TEXT_SIZE_DIP = 18;
private static final float TEXT_SIZE_DIP = 10;
private Integer sensorOrientation;
@@ -103,6 +104,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
tracker = new MultiBoxTracker(getResources().getDisplayMetrics());
@@ -177,21 +179,21 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
canvas.drawBitmap(copy, matrix, new Paint());
final Vector<String> lines = new Vector<String>();
if (detector != null) {
String statString = detector.getStatString();
String[] statLines = statString.split("\n");
for (String line : statLines) {
lines.add(line);
}
}
lines.add("Frame: " + previewWidth + "x" + previewHeight);
lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
lines.add("Rotation: " + sensorOrientation);
lines.add("Inference time: " + lastProcessingTimeMs + "ms");
int lineNum = 0;
for (final String line : lines) {
borderedText.drawText(
canvas,
10,
canvas.getHeight() - 10 - borderedText.getTextSize() * lineNum,
line);
++lineNum;
}
borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
}
}
});
@@ -320,4 +322,9 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
protected int getDesiredPreviewFrameSize() {
return INPUT_SIZE;
}
@Override
public void onSetDebug(boolean debug) {
detector.enableStatLogging(debug);
}
}

View File

@@ -170,6 +170,14 @@ public class TensorFlowImageClassifier implements Classifier {
return recognitions;
}
public void enableStatLogging(boolean debug) {
inferenceInterface.enableStatLogging(debug);
}
public String getStatString() {
return inferenceInterface.getStatString();
}
@Override
public void close() {
inferenceInterface.close();

View File

@@ -211,6 +211,14 @@ public class TensorFlowMultiBoxDetector implements Classifier {
return recognitions;
}
public void enableStatLogging(boolean debug) {
inferenceInterface.enableStatLogging(debug);
}
public String getStatString() {
return inferenceInterface.getStatString();
}
@Override
public void close() {
inferenceInterface.close();

View File

@@ -21,6 +21,8 @@ import android.graphics.Paint;
import android.graphics.Paint.Align;
import android.graphics.Paint.Style;
import android.graphics.Rect;
import android.graphics.Typeface;
import java.util.Vector;
/**
* A class that encapsulates the tedious bits of rendering legible, bordered text onto a canvas.
@@ -68,28 +70,24 @@ public class BorderedText {
this.textSize = textSize;
}
public void setTypeface(Typeface typeface) {
interiorPaint.setTypeface(typeface);
exteriorPaint.setTypeface(typeface);
}
public void drawText(final Canvas canvas, final float posX, final float posY, final String text) {
/*
if (widths == null || widths.length < text.length()) {
widths = new float[text.length()];
positions = new float[text.length() * 2];
}
exteriorPaint.getTextWidths(text, widths);
float lastPosX = posX;
for (int i = 0; i < widths.length; ++i) {
positions[i * 2] = lastPosX;
positions[i * 2 + 1] = posY;
lastPosX += widths[i];
}
*/
//canvas.drawPosText(text, positions, exteriorPaint);
//canvas.drawPosText(text, positions, exteriorPaint);
canvas.drawText(text, posX, posY, exteriorPaint);
canvas.drawText(text, posX, posY, interiorPaint);
}
public void drawLines(Canvas canvas, final float posX, final float posY, Vector<String> lines) {
int lineNum = 0;
for (final String line : lines) {
drawText(canvas, posX, posY - getTextSize() * (lines.size() - lineNum - 1), line);
++lineNum;
}
}
public void setInteriorColor(final int color) {
interiorPaint.setColor(color);
}