Safely read content from InputStream (#12643)

* Safely read content from InputStream

Changed to use `ByteArrayOutputStream` when converting an `InputStream` into a `byte[]`, instead of relying on `InputStream.available()`. (A short sketch of the read pattern follows the change list below.)

* Use a bigger initial buffer size to avoid some copying

* Construct the `byte[]` in each constructor's own preferred way

* Fix an `IOException` that was not caught
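
For context, `InputStream.available()` only returns an estimate of how many bytes can be read without blocking, so sizing a single buffer from it can silently truncate the graph for streams that deliver data incrementally. Below is a minimal sketch of the chunked-read pattern this change adopts for the generic `InputStream` constructor; the standalone helper name `readAllBytes` is illustrative and not part of the commit:

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;

final class StreamUtil {
  // Drains the stream fully instead of trusting InputStream.available().
  static byte[] readAllBytes(InputStream is) throws IOException {
    // available() is only a hint; use it (floored at 16 KiB) as the initial
    // capacity so large graphs need fewer intermediate copies.
    int initialSize = Math.max(is.available(), 16384);
    ByteArrayOutputStream baos = new ByteArrayOutputStream(initialSize);
    byte[] buf = new byte[16384];
    int numBytesRead;
    // read() returns -1 only at end-of-stream, so every byte is captured.
    while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
      baos.write(buf, 0, numBytesRead);
    }
    return baos.toByteArray();
  }
}
```

Used as `byte[] graphDef = readAllBytes(is);` followed by `loadGraph(graphDef, g);`, this mirrors the second constructor in the diff below.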
Authored by resec on 2017-09-08 13:04:40 -05:00, committed by Yifei Feng
parent df849b7767
commit 088a880ad6

TensorFlowInferenceInterface.java

@@ -23,6 +23,7 @@ import android.util.Log;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
@@ -78,10 +79,27 @@ public class TensorFlowInferenceInterface {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
try {
loadGraph(is, g);
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("initializeTensorFlow");
Trace.beginSection("readGraphDef");
}
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
byte[] graphDef = new byte[is.available()];
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // readGraphDef.
}
loadGraph(graphDef, g);
is.close();
Log.i(TAG, "Successfully loaded model from '" + model + "'");
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // initializeTensorFlow.
}
} catch (IOException e) {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
@@ -103,10 +121,32 @@ public class TensorFlowInferenceInterface {
this.g = new Graph();
this.sess = new Session(g);
this.runner = sess.runner();
try {
loadGraph(is, g);
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("initializeTensorFlow");
Trace.beginSection("readGraphDef");
}
int baosInitSize = is.available() > 16384 ? is.available() : 16384;
ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
int numBytesRead;
byte[] buf = new byte[16384];
while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
baos.write(buf, 0, numBytesRead);
}
byte[] graphDef = baos.toByteArray();
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // readGraphDef.
}
loadGraph(graphDef, g);
Log.i(TAG, "Successfully loaded model from the input stream");
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // initializeTensorFlow.
}
} catch (IOException e) {
throw new RuntimeException("Failed to load model from the input stream", e);
}
@@ -458,27 +498,10 @@ public class TensorFlowInferenceInterface {
}
}
private void loadGraph(InputStream is, Graph g) throws IOException {
private void loadGraph(byte[] graphDef, Graph g) throws IOException {
final long startMs = System.currentTimeMillis();
if (VERSION.SDK_INT >= 18) {
Trace.beginSection("loadGraph");
Trace.beginSection("readGraphDef");
}
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
byte[] graphDef = new byte[is.available()];
final int numBytesRead = is.read(graphDef);
if (numBytesRead != graphDef.length) {
throw new IOException(
"read error: read only "
+ numBytesRead
+ " of the graph, expected to read "
+ graphDef.length);
}
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // readGraphDef.
Trace.beginSection("importGraphDef");
}
@@ -490,7 +513,6 @@ public class TensorFlowInferenceInterface {
if (VERSION.SDK_INT >= 18) {
Trace.endSection(); // importGraphDef.
Trace.endSection(); // loadGraph.
}
final long endMs = System.currentTimeMillis();