From bd65cda4051a042003096096ab9b05c73a15e65c Mon Sep 17 00:00:00 2001
From: Asim Shankar <ashankar@google.com>
Date: Tue, 6 Dec 2016 16:59:55 -0800
Subject: [PATCH] Java: LabelImage using Inception example.

Command-line program to:
- Construct an image normalization graph G1
- Execute G1 in a session to produce a tensor of a batch of images
- Import a pre-trained inception model into a graph G2
- Execute G2 in a session to find the best matching label

In other words, the Java equivalent of:
- C++ example https://github.com/tensorflow/tensorflow/tree/abbb4c1/tensorflow/examples/label_image
- Go example
https://github.com/tensorflow/tensorflow/blob/c427b7e89d1498b78d361bfe7345b2636a438893/tensorflow/go/example_inception_inference_test.go

Another step in the journey that is #5
Change: 141247499
---
 tensorflow/examples/label_image/README.md     |  12 +-
 tensorflow/java/README.md                     |   8 +-
 .../main/java/org/tensorflow/examples/BUILD   |   6 +-
 .../java/org/tensorflow/examples/Example.java |  29 ---
 .../org/tensorflow/examples/LabelImage.java   | 208 ++++++++++++++++++
 5 files changed, 223 insertions(+), 40 deletions(-)
 delete mode 100644 tensorflow/java/src/main/java/org/tensorflow/examples/Example.java
 create mode 100644 tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java

diff --git a/tensorflow/examples/label_image/README.md b/tensorflow/examples/label_image/README.md
index e427ff78453..62385312b6f 100644
--- a/tensorflow/examples/label_image/README.md
+++ b/tensorflow/examples/label_image/README.md
@@ -1,7 +1,10 @@
 # TensorFlow C++ Image Recognition Demo
 
 This example shows how you can load a pre-trained TensorFlow network and use it
-to recognize objects in images.
+to recognize objects in images in C++. For Java see the [Java
+README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java),
+and for Go see the [godoc
+example](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go#ex-package).
 
 ## Description
 
@@ -10,9 +13,9 @@ in on the command line.
 
 ## To build/install/run
 
-The TensorFlow `GraphDef` that contains the model definition and weights
-is not packaged in the repo because of its size. Instead, you must
-first download the file to the `data` directory in the source tree:
+The TensorFlow `GraphDef` that contains the model definition and weights is not
+packaged in the repo because of its size. Instead, you must first download the
+file to the `data` directory in the source tree:
 
 ```bash
 $ wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip -O tensorflow/examples/label_image/data/inception_dec_2015.zip
@@ -49,6 +52,7 @@ I tensorflow/examples/label_image/main.cc:207] academic gown (896): 0.0232407
 I tensorflow/examples/label_image/main.cc:207] bow tie (817): 0.0157355
 I tensorflow/examples/label_image/main.cc:207] bolo tie (940): 0.0145023
 ```
+
 In this case, we're using the default image of Admiral Grace Hopper, and you can
 see the network correctly spots she's wearing a military uniform, with a high
 score of 0.6.
diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md
index d9bee5e342e..1eea76c48a6 100644
--- a/tensorflow/java/README.md
+++ b/tensorflow/java/README.md
@@ -40,7 +40,7 @@ bazel build -c opt \
   //tensorflow/java:libtensorflow-jni
 ```
 
-## Example Usage
+## Example
 
 ### With bazel
 
@@ -48,7 +48,7 @@ Add a dependency on `//tensorflow/java:tensorflow` to the `java_binary` or
 `java_library` rule. For example:
 
 ```sh
-bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
+bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:label_image
 ```
 
 ### With `javac`
@@ -58,7 +58,7 @@ bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
     ```sh
     javac \
       -cp ../../bazel-bin/tensorflow/java/libtensorflow.jar \
-      ./src/main/java/org/tensorflow/examples/Example.java
+      ./src/main/java/org/tensorflow/examples/LabelImage.java
     ```
 
 -   Make `libtensorflow.jar` and `libtensorflow-jni.so`
@@ -68,5 +68,5 @@ bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
     java \
       -Djava.library.path=../../bazel-bin/tensorflow/java \
       -cp ../../bazel-bin/tensorflow/java/libtensorflow.jar:./src/main/java \
-      org.tensorflow.examples.Example
+      org.tensorflow.examples.LabelImage
     ```
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD
index 529287a0381..5f9aefef4ce 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD
+++ b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD
@@ -6,9 +6,9 @@ package(default_visibility = ["//visibility:private"])
 licenses(["notice"])  # Apache 2.0
 
 java_binary(
-    name = "example",
-    srcs = ["Example.java"],
-    main_class = "org.tensorflow.examples.Example",
+    name = "label_image",
+    srcs = ["LabelImage.java"],
+    main_class = "org.tensorflow.examples.LabelImage",
     deps = ["//tensorflow/java:tensorflow"],
 )
 
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java b/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java
deleted file mode 100644
index 630632087a2..00000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java
+++ /dev/null
@@ -1,29 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-package org.tensorflow.examples;
-
-import org.tensorflow.TensorFlow;
-
-/**
- * Sample usage of the TensorFlow Java library.
- *
- * <p>This sample should become more useful as functionality is added to the API.
- */
-public class Example {
-  public static void main(String[] args) {
-    System.out.println("TensorFlow version: " + TensorFlow.version());
-  }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java b/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
new file mode 100644
index 00000000000..740248a29bb
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
@@ -0,0 +1,208 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.examples;
+
+import java.io.IOException;
+import java.io.PrintStream;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+import org.tensorflow.DataType;
+import org.tensorflow.Graph;
+import org.tensorflow.Output;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.TensorFlow;
+
+/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
+public class LabelImage {
+  private static void printUsage(PrintStream s) {
+    final String url =
+        "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
+    s.println(
+        "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)");
+    s.println("to label JPEG images.");
+    s.println("TensorFlow version: " + TensorFlow.version());
+    s.println();
+    s.println("Usage: label_image <model dir> <image file>");
+    s.println();
+    s.println("Where:");
+    s.println("<model dir> is a directory containing the unzipped contents of the inception model");
+    s.println("            (from " + url + ")");
+    s.println("<image file> is the path to a JPEG image file");
+  }
+
+  public static void main(String[] args) {
+    if (args.length != 2) {
+      printUsage(System.err);
+      System.exit(1);
+    }
+    String modelDir = args[0];
+    String imageFile = args[1];
+
+    byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
+    List<String> labels =
+        readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
+    byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
+
+    try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
+      float[] labelProbabilities = executeInceptionGraph(graphDef, image);
+      int bestLabelIdx = maxIndex(labelProbabilities);
+      System.out.println(
+          String.format(
+              "BEST MATCH: %s (%.2f%% likely)",
+              labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
+    }
+  }
+
+  private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
+    try (Graph g = new Graph()) {
+      GraphBuilder b = new GraphBuilder(g);
+      // Some constants specific to the pre-trained model at:
+      // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
+      //
+      // - The model was trained with images scaled to 224x224 pixels.
+      // - The colors, represented as R, G, B in 1-byte each were converted to
+      //   float using (value - Mean)/Scale.
+      final int H = 224;
+      final int W = 224;
+      final float mean = 117f;
+      final float scale = 1f;
+
+      // Since the graph is being constructed once per execution here, we can use a constant for the
+      // input image. If the graph were to be re-used for multiple input images, a placeholder would
+      // have been more appropriate.
+      final Output input = b.constant("input", imageBytes);
+      final Output output =
+          b.div(
+              b.sub(
+                  b.resizeBilinear(
+                      b.expandDims(
+                          b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
+                          b.constant("make_batch", 0)),
+                      b.constant("size", new int[] {H, W})),
+                  b.constant("mean", mean)),
+              b.constant("scale", scale));
+      try (Session s = new Session(g)) {
+        return s.runner().fetch(output.op().name()).run().get(0);
+      }
+    }
+  }
+
+  private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
+    try (Graph g = new Graph()) {
+      g.importGraphDef(graphDef);
+      try (Session s = new Session(g);
+          Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
+        final long[] rshape = result.shape();
+        if (result.numDimensions() != 2 || rshape[0] != 1) {
+          throw new RuntimeException(
+              String.format(
+                  "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
+                  Arrays.toString(rshape)));
+        }
+        int nlabels = (int) rshape[1];
+        return result.copyTo(new float[1][nlabels])[0];
+      }
+    }
+  }
+
+  private static int maxIndex(float[] probabilities) {
+    int best = 0;
+    for (int i = 1; i < probabilities.length; ++i) {
+      if (probabilities[i] > probabilities[best]) {
+        best = i;
+      }
+    }
+    return best;
+  }
+
+  private static byte[] readAllBytesOrExit(Path path) {
+    try {
+      return Files.readAllBytes(path);
+    } catch (IOException e) {
+      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
+      System.exit(1);
+    }
+    return null;
+  }
+
+  private static List<String> readAllLinesOrExit(Path path) {
+    try {
+      return Files.readAllLines(path, Charset.forName("UTF-8"));
+    } catch (IOException e) {
+      System.err.println("Failed to read [" + path + "]: " + e.getMessage());
+      System.exit(0);
+    }
+    return null;
+  }
+
+  // In the fullness of time, equivalents of the methods of this class should be auto-generated from
+  // the OpDefs linked into libtensorflow-jni.so. That would match what is done in other languages
+  // like Python, C++ and Go.
+  static class GraphBuilder {
+    GraphBuilder(Graph g) {
+      this.g = g;
+    }
+
+    Output div(Output x, Output y) {
+      return binaryOp("Div", x, y);
+    }
+
+    Output sub(Output x, Output y) {
+      return binaryOp("Sub", x, y);
+    }
+
+    Output resizeBilinear(Output images, Output size) {
+      return binaryOp("ResizeBilinear", images, size);
+    }
+
+    Output expandDims(Output input, Output dim) {
+      return binaryOp("ExpandDims", input, dim);
+    }
+
+    Output cast(Output value, DataType dtype) {
+      return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
+    }
+
+    Output decodeJpeg(Output contents, long channels) {
+      return g.opBuilder("DecodeJpeg", "DecodeJpeg")
+          .addInput(contents)
+          .setAttr("channels", channels)
+          .build()
+          .output(0);
+    }
+
+    Output constant(String name, Object value) {
+      try (Tensor t = Tensor.create(value)) {
+        return g.opBuilder("Const", name)
+            .setAttr("dtype", t.dataType())
+            .setAttr("value", t)
+            .build()
+            .output(0);
+      }
+    }
+
+    private Output binaryOp(String type, Output in1, Output in2) {
+      return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
+    }
+
+    private Graph g;
+  }
+}