Safely read content from InputStream (#12643)
* Safely read content from InputStream Changed to use `ByteArrayOutputStream` on converting `InputStream` into `byte[]`, instead of relying on `InputStream. available()`. * Bigger initial buffer size to avoid some copying * construct byte[] in individual's preference * fix io exception not caught
This commit is contained in:
parent
df849b7767
commit
088a880ad6
@ -23,6 +23,7 @@ import android.util.Log;
|
|||||||
import java.io.FileInputStream;
|
import java.io.FileInputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.DoubleBuffer;
|
import java.nio.DoubleBuffer;
|
||||||
import java.nio.FloatBuffer;
|
import java.nio.FloatBuffer;
|
||||||
@ -78,10 +79,27 @@ public class TensorFlowInferenceInterface {
|
|||||||
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
loadGraph(is, g);
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.beginSection("initializeTensorFlow");
|
||||||
|
Trace.beginSection("readGraphDef");
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
|
||||||
|
byte[] graphDef = new byte[is.available()];
|
||||||
|
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.endSection(); // readGraphDef.
|
||||||
|
}
|
||||||
|
|
||||||
|
loadGraph(graphDef, g);
|
||||||
is.close();
|
is.close();
|
||||||
Log.i(TAG, "Successfully loaded model from '" + model + "'");
|
Log.i(TAG, "Successfully loaded model from '" + model + "'");
|
||||||
|
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.endSection(); // initializeTensorFlow.
|
||||||
|
}
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
throw new RuntimeException("Failed to load model from '" + model + "'", e);
|
||||||
}
|
}
|
||||||
@ -103,10 +121,32 @@ public class TensorFlowInferenceInterface {
|
|||||||
this.g = new Graph();
|
this.g = new Graph();
|
||||||
this.sess = new Session(g);
|
this.sess = new Session(g);
|
||||||
this.runner = sess.runner();
|
this.runner = sess.runner();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
loadGraph(is, g);
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.beginSection("initializeTensorFlow");
|
||||||
|
Trace.beginSection("readGraphDef");
|
||||||
|
}
|
||||||
|
|
||||||
|
int baosInitSize = is.available() > 16384 ? is.available() : 16384;
|
||||||
|
ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize);
|
||||||
|
int numBytesRead;
|
||||||
|
byte[] buf = new byte[16384];
|
||||||
|
while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) {
|
||||||
|
baos.write(buf, 0, numBytesRead);
|
||||||
|
}
|
||||||
|
byte[] graphDef = baos.toByteArray();
|
||||||
|
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.endSection(); // readGraphDef.
|
||||||
|
}
|
||||||
|
|
||||||
|
loadGraph(graphDef, g);
|
||||||
Log.i(TAG, "Successfully loaded model from the input stream");
|
Log.i(TAG, "Successfully loaded model from the input stream");
|
||||||
|
|
||||||
|
if (VERSION.SDK_INT >= 18) {
|
||||||
|
Trace.endSection(); // initializeTensorFlow.
|
||||||
|
}
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException("Failed to load model from the input stream", e);
|
throw new RuntimeException("Failed to load model from the input stream", e);
|
||||||
}
|
}
|
||||||
@ -458,27 +498,10 @@ public class TensorFlowInferenceInterface {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void loadGraph(InputStream is, Graph g) throws IOException {
|
private void loadGraph(byte[] graphDef, Graph g) throws IOException {
|
||||||
final long startMs = System.currentTimeMillis();
|
final long startMs = System.currentTimeMillis();
|
||||||
|
|
||||||
if (VERSION.SDK_INT >= 18) {
|
if (VERSION.SDK_INT >= 18) {
|
||||||
Trace.beginSection("loadGraph");
|
|
||||||
Trace.beginSection("readGraphDef");
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(ashankar): Can we somehow mmap the contents instead of copying them?
|
|
||||||
byte[] graphDef = new byte[is.available()];
|
|
||||||
final int numBytesRead = is.read(graphDef);
|
|
||||||
if (numBytesRead != graphDef.length) {
|
|
||||||
throw new IOException(
|
|
||||||
"read error: read only "
|
|
||||||
+ numBytesRead
|
|
||||||
+ " of the graph, expected to read "
|
|
||||||
+ graphDef.length);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (VERSION.SDK_INT >= 18) {
|
|
||||||
Trace.endSection(); // readGraphDef.
|
|
||||||
Trace.beginSection("importGraphDef");
|
Trace.beginSection("importGraphDef");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -490,7 +513,6 @@ public class TensorFlowInferenceInterface {
|
|||||||
|
|
||||||
if (VERSION.SDK_INT >= 18) {
|
if (VERSION.SDK_INT >= 18) {
|
||||||
Trace.endSection(); // importGraphDef.
|
Trace.endSection(); // importGraphDef.
|
||||||
Trace.endSection(); // loadGraph.
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final long endMs = System.currentTimeMillis();
|
final long endMs = System.currentTimeMillis();
|
||||||
|
Loading…
Reference in New Issue
Block a user