Java: LabelImage using Inception example.
Command-line program to:
- Construct an image normalization graph G1
- Execute G1 in a session to produce a tensor of a batch of images
- Import a pre-trained inception model into a graph G2
- Execute G2 in a session to find the best matching label
In other words, the Java equivalent of:
- C++ example https://github.com/tensorflow/tensorflow/tree/abbb4c1/tensorflow/examples/label_image
- Go example https://github.com/tensorflow/tensorflow/blob/c427b7e89d/tensorflow/go/example_inception_inference_test.go
Another step in the journey that is #5
Change: 141247499
This commit is contained in:
parent
d95969fe12
commit
bd65cda405
@ -1,7 +1,10 @@
|
||||
# TensorFlow C++ Image Recognition Demo
|
||||
|
||||
This example shows how you can load a pre-trained TensorFlow network and use it
|
||||
to recognize objects in images.
|
||||
to recognize objects in images in C++. For Java see the [Java
|
||||
README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java),
|
||||
and for Go see the [godoc
|
||||
example](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go#ex-package).
|
||||
|
||||
## Description
|
||||
|
||||
@ -10,9 +13,9 @@ in on the command line.
|
||||
|
||||
## To build/install/run
|
||||
|
||||
The TensorFlow `GraphDef` that contains the model definition and weights
|
||||
is not packaged in the repo because of its size. Instead, you must
|
||||
first download the file to the `data` directory in the source tree:
|
||||
The TensorFlow `GraphDef` that contains the model definition and weights is not
|
||||
packaged in the repo because of its size. Instead, you must first download the
|
||||
file to the `data` directory in the source tree:
|
||||
|
||||
```bash
|
||||
$ wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip -O tensorflow/examples/label_image/data/inception_dec_2015.zip
|
||||
@ -49,6 +52,7 @@ I tensorflow/examples/label_image/main.cc:207] academic gown (896): 0.0232407
|
||||
I tensorflow/examples/label_image/main.cc:207] bow tie (817): 0.0157355
|
||||
I tensorflow/examples/label_image/main.cc:207] bolo tie (940): 0.0145023
|
||||
```
|
||||
|
||||
In this case, we're using the default image of Admiral Grace Hopper, and you can
|
||||
see the network correctly spots she's wearing a military uniform, with a high
|
||||
score of 0.6.
|
||||
|
@ -40,7 +40,7 @@ bazel build -c opt \
|
||||
//tensorflow/java:libtensorflow-jni
|
||||
```
|
||||
|
||||
## Example Usage
|
||||
## Example
|
||||
|
||||
### With bazel
|
||||
|
||||
@ -48,7 +48,7 @@ Add a dependency on `//tensorflow/java:tensorflow` to the `java_binary` or
|
||||
`java_library` rule. For example:
|
||||
|
||||
```sh
|
||||
bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
|
||||
bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:label_image
|
||||
```
|
||||
|
||||
### With `javac`
|
||||
@ -58,7 +58,7 @@ bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
|
||||
```sh
|
||||
javac \
|
||||
-cp ../../bazel-bin/tensorflow/java/libtensorflow.jar \
|
||||
./src/main/java/org/tensorflow/examples/Example.java
|
||||
./src/main/java/org/tensorflow/examples/LabelImage.java
|
||||
```
|
||||
|
||||
- Make `libtensorflow.jar` and `libtensorflow-jni.so`
|
||||
@ -68,5 +68,5 @@ bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
|
||||
java \
|
||||
-Djava.library.path=../../bazel-bin/tensorflow/java \
|
||||
-cp ../../bazel-bin/tensorflow/java/libtensorflow.jar:./src/main/java \
|
||||
org.tensorflow.examples.Example
|
||||
org.tensorflow.examples.LabelImage
|
||||
```
|
||||
|
@ -6,9 +6,9 @@ package(default_visibility = ["//visibility:private"])
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
java_binary(
|
||||
name = "example",
|
||||
srcs = ["Example.java"],
|
||||
main_class = "org.tensorflow.examples.Example",
|
||||
name = "label_image",
|
||||
srcs = ["LabelImage.java"],
|
||||
main_class = "org.tensorflow.examples.LabelImage",
|
||||
deps = ["//tensorflow/java:tensorflow"],
|
||||
)
|
||||
|
||||
|
@ -1,29 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.examples;
|
||||
|
||||
import org.tensorflow.TensorFlow;
|
||||
|
||||
/**
|
||||
* Sample usage of the TensorFlow Java library.
|
||||
*
|
||||
* <p>This sample should become more useful as functionality is added to the API.
|
||||
*/
|
||||
public class Example {
|
||||
public static void main(String[] args) {
|
||||
System.out.println("TensorFlow version: " + TensorFlow.version());
|
||||
}
|
||||
}
|
@ -0,0 +1,208 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.examples;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.PrintStream;
|
||||
import java.nio.charset.Charset;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.tensorflow.DataType;
|
||||
import org.tensorflow.Graph;
|
||||
import org.tensorflow.Output;
|
||||
import org.tensorflow.Session;
|
||||
import org.tensorflow.Tensor;
|
||||
import org.tensorflow.TensorFlow;
|
||||
|
||||
/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
|
||||
public class LabelImage {
|
||||
private static void printUsage(PrintStream s) {
|
||||
final String url =
|
||||
"https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
|
||||
s.println(
|
||||
"Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)");
|
||||
s.println("to label JPEG images.");
|
||||
s.println("TensorFlow version: " + TensorFlow.version());
|
||||
s.println();
|
||||
s.println("Usage: label_image <model dir> <image file>");
|
||||
s.println();
|
||||
s.println("Where:");
|
||||
s.println("<model dir> is a directory containing the unzipped contents of the inception model");
|
||||
s.println(" (from " + url + ")");
|
||||
s.println("<image file> is the path to a JPEG image file");
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
if (args.length != 2) {
|
||||
printUsage(System.err);
|
||||
System.exit(1);
|
||||
}
|
||||
String modelDir = args[0];
|
||||
String imageFile = args[1];
|
||||
|
||||
byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
|
||||
List<String> labels =
|
||||
readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
|
||||
byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
|
||||
|
||||
try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
|
||||
float[] labelProbabilities = executeInceptionGraph(graphDef, image);
|
||||
int bestLabelIdx = maxIndex(labelProbabilities);
|
||||
System.out.println(
|
||||
String.format(
|
||||
"BEST MATCH: %s (%.2f%% likely)",
|
||||
labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
|
||||
}
|
||||
}
|
||||
|
||||
private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
|
||||
try (Graph g = new Graph()) {
|
||||
GraphBuilder b = new GraphBuilder(g);
|
||||
// Some constants specific to the pre-trained model at:
|
||||
// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
|
||||
//
|
||||
// - The model was trained with images scaled to 224x224 pixels.
|
||||
// - The colors, represented as R, G, B in 1-byte each were converted to
|
||||
// float using (value - Mean)/Scale.
|
||||
final int H = 224;
|
||||
final int W = 224;
|
||||
final float mean = 117f;
|
||||
final float scale = 1f;
|
||||
|
||||
// Since the graph is being constructed once per execution here, we can use a constant for the
|
||||
// input image. If the graph were to be re-used for multiple input images, a placeholder would
|
||||
// have been more appropriate.
|
||||
final Output input = b.constant("input", imageBytes);
|
||||
final Output output =
|
||||
b.div(
|
||||
b.sub(
|
||||
b.resizeBilinear(
|
||||
b.expandDims(
|
||||
b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
|
||||
b.constant("make_batch", 0)),
|
||||
b.constant("size", new int[] {H, W})),
|
||||
b.constant("mean", mean)),
|
||||
b.constant("scale", scale));
|
||||
try (Session s = new Session(g)) {
|
||||
return s.runner().fetch(output.op().name()).run().get(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
|
||||
try (Graph g = new Graph()) {
|
||||
g.importGraphDef(graphDef);
|
||||
try (Session s = new Session(g);
|
||||
Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
|
||||
final long[] rshape = result.shape();
|
||||
if (result.numDimensions() != 2 || rshape[0] != 1) {
|
||||
throw new RuntimeException(
|
||||
String.format(
|
||||
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
|
||||
Arrays.toString(rshape)));
|
||||
}
|
||||
int nlabels = (int) rshape[1];
|
||||
return result.copyTo(new float[1][nlabels])[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static int maxIndex(float[] probabilities) {
|
||||
int best = 0;
|
||||
for (int i = 1; i < probabilities.length; ++i) {
|
||||
if (probabilities[i] > probabilities[best]) {
|
||||
best = i;
|
||||
}
|
||||
}
|
||||
return best;
|
||||
}
|
||||
|
||||
private static byte[] readAllBytesOrExit(Path path) {
|
||||
try {
|
||||
return Files.readAllBytes(path);
|
||||
} catch (IOException e) {
|
||||
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
|
||||
System.exit(1);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static List<String> readAllLinesOrExit(Path path) {
|
||||
try {
|
||||
return Files.readAllLines(path, Charset.forName("UTF-8"));
|
||||
} catch (IOException e) {
|
||||
System.err.println("Failed to read [" + path + "]: " + e.getMessage());
|
||||
System.exit(0);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// In the fullness of time, equivalents of the methods of this class should be auto-generated from
|
||||
// the OpDefs linked into libtensorflow-jni.so. That would match what is done in other languages
|
||||
// like Python, C++ and Go.
|
||||
static class GraphBuilder {
|
||||
GraphBuilder(Graph g) {
|
||||
this.g = g;
|
||||
}
|
||||
|
||||
Output div(Output x, Output y) {
|
||||
return binaryOp("Div", x, y);
|
||||
}
|
||||
|
||||
Output sub(Output x, Output y) {
|
||||
return binaryOp("Sub", x, y);
|
||||
}
|
||||
|
||||
Output resizeBilinear(Output images, Output size) {
|
||||
return binaryOp("ResizeBilinear", images, size);
|
||||
}
|
||||
|
||||
Output expandDims(Output input, Output dim) {
|
||||
return binaryOp("ExpandDims", input, dim);
|
||||
}
|
||||
|
||||
Output cast(Output value, DataType dtype) {
|
||||
return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
|
||||
}
|
||||
|
||||
Output decodeJpeg(Output contents, long channels) {
|
||||
return g.opBuilder("DecodeJpeg", "DecodeJpeg")
|
||||
.addInput(contents)
|
||||
.setAttr("channels", channels)
|
||||
.build()
|
||||
.output(0);
|
||||
}
|
||||
|
||||
Output constant(String name, Object value) {
|
||||
try (Tensor t = Tensor.create(value)) {
|
||||
return g.opBuilder("Const", name)
|
||||
.setAttr("dtype", t.dataType())
|
||||
.setAttr("value", t)
|
||||
.build()
|
||||
.output(0);
|
||||
}
|
||||
}
|
||||
|
||||
private Output binaryOp(String type, Output in1, Output in2) {
|
||||
return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
|
||||
}
|
||||
|
||||
private Graph g;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user