From bd65cda4051a042003096096ab9b05c73a15e65c Mon Sep 17 00:00:00 2001 From: Asim Shankar <ashankar@google.com> Date: Tue, 6 Dec 2016 16:59:55 -0800 Subject: [PATCH] Java: LabelImage using Inception example. Command-line program to: - Construct an image normalization graph G1 - Execute G1 in a session to produce a tensor of a batch of images - Import a pre-trained inception model into a graph G2 - Execute G2 in a session to find the best matching label In other words, the Java equivalent of: - C++ example https://github.com/tensorflow/tensorflow/tree/abbb4c1/tensorflow/examples/label_image - Go example https://github.com/tensorflow/tensorflow/blob/c427b7e89d1498b78d361bfe7345b2636a438893/tensorflow/go/example_inception_inference_test.go Another step in the journey that is #5 Change: 141247499 --- tensorflow/examples/label_image/README.md | 12 +- tensorflow/java/README.md | 8 +- .../main/java/org/tensorflow/examples/BUILD | 6 +- .../java/org/tensorflow/examples/Example.java | 29 --- .../org/tensorflow/examples/LabelImage.java | 208 ++++++++++++++++++ 5 files changed, 223 insertions(+), 40 deletions(-) delete mode 100644 tensorflow/java/src/main/java/org/tensorflow/examples/Example.java create mode 100644 tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java diff --git a/tensorflow/examples/label_image/README.md b/tensorflow/examples/label_image/README.md index e427ff78453..62385312b6f 100644 --- a/tensorflow/examples/label_image/README.md +++ b/tensorflow/examples/label_image/README.md @@ -1,7 +1,10 @@ # TensorFlow C++ Image Recognition Demo This example shows how you can load a pre-trained TensorFlow network and use it -to recognize objects in images. +to recognize objects in images in C++. For Java see the [Java +README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java), +and for Go see the [godoc +example](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go#ex-package). 
## Description @@ -10,9 +13,9 @@ in on the command line. ## To build/install/run -The TensorFlow `GraphDef` that contains the model definition and weights -is not packaged in the repo because of its size. Instead, you must -first download the file to the `data` directory in the source tree: +The TensorFlow `GraphDef` that contains the model definition and weights is not +packaged in the repo because of its size. Instead, you must first download the +file to the `data` directory in the source tree: ```bash $ wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip -O tensorflow/examples/label_image/data/inception_dec_2015.zip @@ -49,6 +52,7 @@ I tensorflow/examples/label_image/main.cc:207] academic gown (896): 0.0232407 I tensorflow/examples/label_image/main.cc:207] bow tie (817): 0.0157355 I tensorflow/examples/label_image/main.cc:207] bolo tie (940): 0.0145023 ``` + In this case, we're using the default image of Admiral Grace Hopper, and you can see the network correctly spots she's wearing a military uniform, with a high score of 0.6. diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md index d9bee5e342e..1eea76c48a6 100644 --- a/tensorflow/java/README.md +++ b/tensorflow/java/README.md @@ -40,7 +40,7 @@ bazel build -c opt \ //tensorflow/java:libtensorflow-jni ``` -## Example Usage +## Example ### With bazel @@ -48,7 +48,7 @@ Add a dependency on `//tensorflow/java:tensorflow` to the `java_binary` or `java_library` rule. 
For example: ```sh -bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example +bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:label_image ``` ### With `javac` @@ -58,7 +58,7 @@ bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example ```sh javac \ -cp ../../bazel-bin/tensorflow/java/libtensorflow.jar \ - ./src/main/java/org/tensorflow/examples/Example.java + ./src/main/java/org/tensorflow/examples/LabelImage.java ``` - Make `libtensorflow.jar` and `libtensorflow-jni.so` @@ -68,5 +68,5 @@ bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example java \ -Djava.library.path=../../bazel-bin/tensorflow/java \ -cp ../../bazel-bin/tensorflow/java/libtensorflow.jar:./src/main/java \ - org.tensorflow.examples.Example + org.tensorflow.examples.LabelImage ``` diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD index 529287a0381..5f9aefef4ce 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD +++ b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD @@ -6,9 +6,9 @@ package(default_visibility = ["//visibility:private"]) licenses(["notice"]) # Apache 2.0 java_binary( - name = "example", - srcs = ["Example.java"], - main_class = "org.tensorflow.examples.Example", + name = "label_image", + srcs = ["LabelImage.java"], + main_class = "org.tensorflow.examples.LabelImage", deps = ["//tensorflow/java:tensorflow"], ) diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java b/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java deleted file mode 100644 index 630632087a2..00000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -package org.tensorflow.examples; - -import org.tensorflow.TensorFlow; - -/** - * Sample usage of the TensorFlow Java library. - * - * <p>This sample should become more useful as functionality is added to the API. - */ -public class Example { - public static void main(String[] args) { - System.out.println("TensorFlow version: " + TensorFlow.version()); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java b/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java new file mode 100644 index 00000000000..740248a29bb --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java @@ -0,0 +1,208 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.examples; + +import java.io.IOException; +import java.io.PrintStream; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.List; +import org.tensorflow.DataType; +import org.tensorflow.Graph; +import org.tensorflow.Output; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.TensorFlow; + +/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */ +public class LabelImage { + private static void printUsage(PrintStream s) { + final String url = + "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"; + s.println( + "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)"); + s.println("to label JPEG images."); + s.println("TensorFlow version: " + TensorFlow.version()); + s.println(); + s.println("Usage: label_image <model dir> <image file>"); + s.println(); + s.println("Where:"); + s.println("<model dir> is a directory containing the unzipped contents of the inception model"); + s.println(" (from " + url + ")"); + s.println("<image file> is the path to a JPEG image file"); + } + + public static void main(String[] args) { + if (args.length != 2) { + printUsage(System.err); + System.exit(1); + } + String modelDir = args[0]; + String imageFile = args[1]; + + byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb")); + List<String> labels = + readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt")); + byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile)); + + try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) { + float[] labelProbabilities = executeInceptionGraph(graphDef, image); + int bestLabelIdx = maxIndex(labelProbabilities); + System.out.println( 
+ String.format( + "BEST MATCH: %s (%.2f%% likely)", + labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f)); + } + } + + private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) { + try (Graph g = new Graph()) { + GraphBuilder b = new GraphBuilder(g); + // Some constants specific to the pre-trained model at: + // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip + // + // - The model was trained with images scaled to 224x224 pixels. + // - The colors, represented as R, G, B in 1-byte each were converted to + // float using (value - Mean)/Scale. + final int H = 224; + final int W = 224; + final float mean = 117f; + final float scale = 1f; + + // Since the graph is being constructed once per execution here, we can use a constant for the + // input image. If the graph were to be re-used for multiple input images, a placeholder would + // have been more appropriate. + final Output input = b.constant("input", imageBytes); + final Output output = + b.div( + b.sub( + b.resizeBilinear( + b.expandDims( + b.cast(b.decodeJpeg(input, 3), DataType.FLOAT), + b.constant("make_batch", 0)), + b.constant("size", new int[] {H, W})), + b.constant("mean", mean)), + b.constant("scale", scale)); + try (Session s = new Session(g)) { + return s.runner().fetch(output.op().name()).run().get(0); + } + } + } + + private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) { + try (Graph g = new Graph()) { + g.importGraphDef(graphDef); + try (Session s = new Session(g); + Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) { + final long[] rshape = result.shape(); + if (result.numDimensions() != 2 || rshape[0] != 1) { + throw new RuntimeException( + String.format( + "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s", + Arrays.toString(rshape))); + } + int nlabels = (int) rshape[1]; + return result.copyTo(new 
float[1][nlabels])[0]; + } + } + } + + private static int maxIndex(float[] probabilities) { + int best = 0; + for (int i = 1; i < probabilities.length; ++i) { + if (probabilities[i] > probabilities[best]) { + best = i; + } + } + return best; + } + + private static byte[] readAllBytesOrExit(Path path) { + try { + return Files.readAllBytes(path); + } catch (IOException e) { + System.err.println("Failed to read [" + path + "]: " + e.getMessage()); + System.exit(1); + } + return null; + } + + private static List<String> readAllLinesOrExit(Path path) { + try { + return Files.readAllLines(path, Charset.forName("UTF-8")); + } catch (IOException e) { + System.err.println("Failed to read [" + path + "]: " + e.getMessage()); + System.exit(1); + } + return null; + } + + // In the fullness of time, equivalents of the methods of this class should be auto-generated from + // the OpDefs linked into libtensorflow-jni.so. That would match what is done in other languages + // like Python, C++ and Go. + static class GraphBuilder { + GraphBuilder(Graph g) { + this.g = g; + } + + Output div(Output x, Output y) { + return binaryOp("Div", x, y); + } + + Output sub(Output x, Output y) { + return binaryOp("Sub", x, y); + } + + Output resizeBilinear(Output images, Output size) { + return binaryOp("ResizeBilinear", images, size); + } + + Output expandDims(Output input, Output dim) { + return binaryOp("ExpandDims", input, dim); + } + + Output cast(Output value, DataType dtype) { + return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0); + } + + Output decodeJpeg(Output contents, long channels) { + return g.opBuilder("DecodeJpeg", "DecodeJpeg") + .addInput(contents) + .setAttr("channels", channels) + .build() + .output(0); + } + + Output constant(String name, Object value) { + try (Tensor t = Tensor.create(value)) { + return g.opBuilder("Const", name) + .setAttr("dtype", t.dataType()) + .setAttr("value", t) + .build() + .output(0); + } + } + + private Output
binaryOp(String type, Output in1, Output in2) { + return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0); + } + + private Graph g; + } +}