From 088a880ad6adb4445643f5f5ebaeda32d5c2b1f6 Mon Sep 17 00:00:00 2001 From: resec Date: Fri, 8 Sep 2017 13:04:40 -0500 Subject: [PATCH] Safely read content from InputStream (#12643) * Safely read content from InputStream Changed to use `ByteArrayOutputStream` on converting `InputStream` into `byte[]`, instead of relying on `InputStream. available()`. * Bigger initial buffer size to avoid some copying * construct byte[] in individual's preference * fix io exception not caught --- .../android/TensorFlowInferenceInterface.java | 66 ++++++++++++------- 1 file changed, 44 insertions(+), 22 deletions(-) diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 6389ef1f5da..090621f29e0 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -23,6 +23,7 @@ import android.util.Log; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; +import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; import java.nio.DoubleBuffer; import java.nio.FloatBuffer; @@ -78,10 +79,27 @@ public class TensorFlowInferenceInterface { throw new RuntimeException("Failed to load model from '" + model + "'", e); } } + try { - loadGraph(is, g); + if (VERSION.SDK_INT >= 18) { + Trace.beginSection("initializeTensorFlow"); + Trace.beginSection("readGraphDef"); + } + + // TODO(ashankar): Can we somehow mmap the contents instead of copying them? + byte[] graphDef = new byte[is.available()]; + + if (VERSION.SDK_INT >= 18) { + Trace.endSection(); // readGraphDef. + } + + loadGraph(graphDef, g); is.close(); Log.i(TAG, "Successfully loaded model from '" + model + "'"); + + if (VERSION.SDK_INT >= 18) { + Trace.endSection(); // initializeTensorFlow. + } } catch (IOException e) { throw new RuntimeException("Failed to load model from '" + model + "'", e); } @@ -103,10 +121,32 @@ public class TensorFlowInferenceInterface { this.g = new Graph(); this.sess = new Session(g); this.runner = sess.runner(); - + try { - loadGraph(is, g); + if (VERSION.SDK_INT >= 18) { + Trace.beginSection("initializeTensorFlow"); + Trace.beginSection("readGraphDef"); + } + + int baosInitSize = is.available() > 16384 ? is.available() : 16384; + ByteArrayOutputStream baos = new ByteArrayOutputStream(baosInitSize); + int numBytesRead; + byte[] buf = new byte[16384]; + while ((numBytesRead = is.read(buf, 0, buf.length)) != -1) { + baos.write(buf, 0, numBytesRead); + } + byte[] graphDef = baos.toByteArray(); + + if (VERSION.SDK_INT >= 18) { + Trace.endSection(); // readGraphDef. + } + + loadGraph(graphDef, g); Log.i(TAG, "Successfully loaded model from the input stream"); + + if (VERSION.SDK_INT >= 18) { + Trace.endSection(); // initializeTensorFlow. + } } catch (IOException e) { throw new RuntimeException("Failed to load model from the input stream", e); } @@ -458,27 +498,10 @@ public class TensorFlowInferenceInterface { } } - private void loadGraph(InputStream is, Graph g) throws IOException { + private void loadGraph(byte[] graphDef, Graph g) throws IOException { final long startMs = System.currentTimeMillis(); if (VERSION.SDK_INT >= 18) { - Trace.beginSection("loadGraph"); - Trace.beginSection("readGraphDef"); - } - - // TODO(ashankar): Can we somehow mmap the contents instead of copying them? - byte[] graphDef = new byte[is.available()]; - final int numBytesRead = is.read(graphDef); - if (numBytesRead != graphDef.length) { - throw new IOException( - "read error: read only " - + numBytesRead - + " of the graph, expected to read " - + graphDef.length); - } - - if (VERSION.SDK_INT >= 18) { - Trace.endSection(); // readGraphDef. Trace.beginSection("importGraphDef"); } @@ -490,7 +513,6 @@ public class TensorFlowInferenceInterface { if (VERSION.SDK_INT >= 18) { Trace.endSection(); // importGraphDef. - Trace.endSection(); // loadGraph. } final long endMs = System.currentTimeMillis();