Safely read content from InputStream (#12643)
* Safely read content from InputStream Changed to use `ByteArrayOutputStream` on converting `InputStream` into `byte[]`, instead of relying on `InputStream. available()`. * Bigger initial buffer size to avoid some copying * construct byte[] in individual's preference * fix io exception not caught
This commit is contained in:
parent
df849b7767
commit
088a880ad6
@ -23,6 +23,7 @@ import android.util.Log;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
@ -78,10 +79,27 @@ public class TensorFlowInferenceInterface {
|
||||
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
loadGraph(is, g);
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.beginSection("initializeTensorFlow");
|
||||
Trace.beginSection("readGraphDef");
|
||||
}
|
||||
|
||||
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
|
||||
byte[] graphDef = new byte[is.available()];
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // readGraphDef.
|
||||
}
|
||||
|
||||
loadGraph(graphDef, g);
|
||||
is.close();
|
||||
Log.i(TAG, "Successfully loaded model from '" + model + "'");
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // initializeTensorFlow.
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||
}
|
||||
@ -103,10 +121,32 @@ public class TensorFlowInferenceInterface {
|
||||
this.g = new Graph();
|
||||
this.sess = new Session(g);
|
||||
this.runner = sess.runner();
|
||||
|
||||
|
||||
try {
|
||||
loadGraph(is, g);
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.beginSection("initializeTensorFlow");
|
||||
Trace.beginSection("readGraphDef");
|
||||
}
|
||||
|
||||
int baosInitSize = is.available() > 16384 ? is.available() : 16384;
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
|
||||
int numBytesRead;
|
||||
byte[] buf = new byte[16384];
|
||||
while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
|
||||
baos.write(buf, 0, numBytesRead);
|
||||
}
|
||||
byte[] graphDef = baos.toByteArray();
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // readGraphDef.
|
||||
}
|
||||
|
||||
loadGraph(graphDef, g);
|
||||
Log.i(TAG, "Successfully loaded model from the input stream");
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // initializeTensorFlow.
|
||||
}
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to load model from the input stream", e);
|
||||
}
|
||||
@ -458,27 +498,10 @@ public class TensorFlowInferenceInterface {
|
||||
}
|
||||
}
|
||||
|
||||
private void loadGraph(InputStream is, Graph g) throws IOException {
|
||||
private void loadGraph(byte[] graphDef, Graph g) throws IOException {
|
||||
final long startMs = System.currentTimeMillis();
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.beginSection("loadGraph");
|
||||
Trace.beginSection("readGraphDef");
|
||||
}
|
||||
|
||||
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
|
||||
byte[] graphDef = new byte[is.available()];
|
||||
final int numBytesRead = is.read(graphDef);
|
||||
if (numBytesRead != graphDef.length) {
|
||||
throw new IOException(
|
||||
"read error: read only "
|
||||
+ numBytesRead
|
||||
+ " of the graph, expected to read "
|
||||
+ graphDef.length);
|
||||
}
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // readGraphDef.
|
||||
Trace.beginSection("importGraphDef");
|
||||
}
|
||||
|
||||
@ -490,7 +513,6 @@ public class TensorFlowInferenceInterface {
|
||||
|
||||
if (VERSION.SDK_INT >= 18) {
|
||||
Trace.endSection(); // importGraphDef.
|
||||
Trace.endSection(); // loadGraph.
|
||||
}
|
||||
|
||||
final long endMs = System.currentTimeMillis();
|
||||
|
Loading…
Reference in New Issue
Block a user