Java: LabelImage using Inception example.

Command-line program to:
- Construct an image normalization graph G1
- Execute G1 in a session to produce a tensor of a batch of images
- Import a pre-trained inception model into a graph G2
- Execute G2 in a session to find the best matching label

In other words, the Java equivalent of:
- C++ example https://github.com/tensorflow/tensorflow/tree/abbb4c1/tensorflow/examples/label_image
- Go example
c427b7e89d/tensorflow/go/example_inception_inference_test.go

Another step in the journey that is #5
Change: 141247499
This commit is contained in:
Asim Shankar 2016-12-06 16:59:55 -08:00 committed by TensorFlower Gardener
parent d95969fe12
commit bd65cda405
5 changed files with 223 additions and 40 deletions

View File

@ -1,7 +1,10 @@
# TensorFlow C++ Image Recognition Demo
This example shows how you can load a pre-trained TensorFlow network and use it
to recognize objects in images.
to recognize objects in images in C++. For Java see the [Java
README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java),
and for Go see the [godoc
example](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go#ex-package).
## Description
@ -10,9 +13,9 @@ in on the command line.
## To build/install/run
The TensorFlow `GraphDef` that contains the model definition and weights
is not packaged in the repo because of its size. Instead, you must
first download the file to the `data` directory in the source tree:
The TensorFlow `GraphDef` that contains the model definition and weights is not
packaged in the repo because of its size. Instead, you must first download the
file to the `data` directory in the source tree:
```bash
$ wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip -O tensorflow/examples/label_image/data/inception_dec_2015.zip
@ -49,6 +52,7 @@ I tensorflow/examples/label_image/main.cc:207] academic gown (896): 0.0232407
I tensorflow/examples/label_image/main.cc:207] bow tie (817): 0.0157355
I tensorflow/examples/label_image/main.cc:207] bolo tie (940): 0.0145023
```
In this case, we're using the default image of Admiral Grace Hopper, and you can
see the network correctly spots she's wearing a military uniform, with a high
score of 0.6.

View File

@ -40,7 +40,7 @@ bazel build -c opt \
//tensorflow/java:libtensorflow-jni
```
## Example Usage
## Example
### With bazel
@ -48,7 +48,7 @@ Add a dependency on `//tensorflow/java:tensorflow` to the `java_binary` or
`java_library` rule. For example:
```sh
bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:label_image
```
### With `javac`
@ -58,7 +58,7 @@ bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
```sh
javac \
-cp ../../bazel-bin/tensorflow/java/libtensorflow.jar \
./src/main/java/org/tensorflow/examples/Example.java
./src/main/java/org/tensorflow/examples/LabelImage.java
```
- Make `libtensorflow.jar` and `libtensorflow-jni.so`
@ -68,5 +68,5 @@ bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
java \
-Djava.library.path=../../bazel-bin/tensorflow/java \
-cp ../../bazel-bin/tensorflow/java/libtensorflow.jar:./src/main/java \
org.tensorflow.examples.Example
org.tensorflow.examples.LabelImage
```

View File

@ -6,9 +6,9 @@ package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
java_binary(
name = "example",
srcs = ["Example.java"],
main_class = "org.tensorflow.examples.Example",
name = "label_image",
srcs = ["LabelImage.java"],
main_class = "org.tensorflow.examples.LabelImage",
deps = ["//tensorflow/java:tensorflow"],
)

View File

@ -1,29 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.examples;
import org.tensorflow.TensorFlow;
/**
* Sample usage of the TensorFlow Java library.
*
* <p>This sample should become more useful as functionality is added to the API.
*/
public class Example {
public static void main(String[] args) {
System.out.println("TensorFlow version: " + TensorFlow.version());
}
}

View File

@ -0,0 +1,208 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow.examples;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
public class LabelImage {
private static void printUsage(PrintStream s) {
final String url =
"https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
s.println(
"Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)");
s.println("to label JPEG images.");
s.println("TensorFlow version: " + TensorFlow.version());
s.println();
s.println("Usage: label_image <model dir> <image file>");
s.println();
s.println("Where:");
s.println("<model dir> is a directory containing the unzipped contents of the inception model");
s.println(" (from " + url + ")");
s.println("<image file> is the path to a JPEG image file");
}
public static void main(String[] args) {
if (args.length != 2) {
printUsage(System.err);
System.exit(1);
}
String modelDir = args[0];
String imageFile = args[1];
byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
List<String> labels =
readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
float[] labelProbabilities = executeInceptionGraph(graphDef, image);
int bestLabelIdx = maxIndex(labelProbabilities);
System.out.println(
String.format(
"BEST MATCH: %s (%.2f%% likely)",
labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
}
}
private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
try (Graph g = new Graph()) {
GraphBuilder b = new GraphBuilder(g);
// Some constants specific to the pre-trained model at:
// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
//
// - The model was trained with images scaled to 224x224 pixels.
// - The colors, represented as R, G, B in 1-byte each were converted to
// float using (value - Mean)/Scale.
final int H = 224;
final int W = 224;
final float mean = 117f;
final float scale = 1f;
// Since the graph is being constructed once per execution here, we can use a constant for the
// input image. If the graph were to be re-used for multiple input images, a placeholder would
// have been more appropriate.
final Output input = b.constant("input", imageBytes);
final Output output =
b.div(
b.sub(
b.resizeBilinear(
b.expandDims(
b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
b.constant("make_batch", 0)),
b.constant("size", new int[] {H, W})),
b.constant("mean", mean)),
b.constant("scale", scale));
try (Session s = new Session(g)) {
return s.runner().fetch(output.op().name()).run().get(0);
}
}
}
private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
try (Graph g = new Graph()) {
g.importGraphDef(graphDef);
try (Session s = new Session(g);
Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
final long[] rshape = result.shape();
if (result.numDimensions() != 2 || rshape[0] != 1) {
throw new RuntimeException(
String.format(
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
Arrays.toString(rshape)));
}
int nlabels = (int) rshape[1];
return result.copyTo(new float[1][nlabels])[0];
}
}
}
private static int maxIndex(float[] probabilities) {
int best = 0;
for (int i = 1; i < probabilities.length; ++i) {
if (probabilities[i] > probabilities[best]) {
best = i;
}
}
return best;
}
private static byte[] readAllBytesOrExit(Path path) {
try {
return Files.readAllBytes(path);
} catch (IOException e) {
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
System.exit(1);
}
return null;
}
private static List<String> readAllLinesOrExit(Path path) {
try {
return Files.readAllLines(path, Charset.forName("UTF-8"));
} catch (IOException e) {
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
System.exit(0);
}
return null;
}
// In the fullness of time, equivalents of the methods of this class should be auto-generated from
// the OpDefs linked into libtensorflow-jni.so. That would match what is done in other languages
// like Python, C++ and Go.
static class GraphBuilder {
GraphBuilder(Graph g) {
this.g = g;
}
Output div(Output x, Output y) {
return binaryOp("Div", x, y);
}
Output sub(Output x, Output y) {
return binaryOp("Sub", x, y);
}
Output resizeBilinear(Output images, Output size) {
return binaryOp("ResizeBilinear", images, size);
}
Output expandDims(Output input, Output dim) {
return binaryOp("ExpandDims", input, dim);
}
Output cast(Output value, DataType dtype) {
return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
}
Output decodeJpeg(Output contents, long channels) {
return g.opBuilder("DecodeJpeg", "DecodeJpeg")
.addInput(contents)
.setAttr("channels", channels)
.build()
.output(0);
}
Output constant(String name, Object value) {
try (Tensor t = Tensor.create(value)) {
return g.opBuilder("Const", name)
.setAttr("dtype", t.dataType())
.setAttr("value", t)
.build()
.output(0);
}
}
private Output binaryOp(String type, Output in1, Output in2) {
return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
}
private Graph g;
}
}