This change fixes several issues related to `go fmt` and `go lint`, based on the report at https://goreportcard.com/report/github.com/tensorflow/tensorflow. The changes are: - `gofmt -s tensorflow/go/tensor.go` - `gofmt -s tensorflow/go/example_inception_inference_test.go` - `golint tensorflow/go/genop/internal/lib.go` At the moment there are still quite a few golint and ineffassign warnings in the current Go code base. However, all of them come from `tensorflow/go/op/wrappers.go`, which is machine-generated code, so this fix does not cover `tensorflow/go/op/wrappers.go`. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
292 lines
9.0 KiB
Go
292 lines
9.0 KiB
Go
/*
|
|
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package tensorflow_test
|
|
|
|
import (
|
|
"archive/zip"
|
|
"bufio"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
|
|
tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
|
"github.com/tensorflow/tensorflow/tensorflow/go/op"
|
|
)
|
|
|
|
func Example() {
|
|
// An example for using the TensorFlow Go API for image recognition
|
|
// using a pre-trained inception model (http://arxiv.org/abs/1512.00567).
|
|
//
|
|
// Sample usage: <program> -dir=/tmp/modeldir -image=/path/to/some/jpeg
|
|
//
|
|
// The pre-trained model takes input in the form of a 4-dimensional
|
|
// tensor with shape [ BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3 ],
|
|
// where:
|
|
// - BATCH_SIZE allows for inference of multiple images in one pass through the graph
|
|
// - IMAGE_HEIGHT is the height of the images on which the model was trained
|
|
// - IMAGE_WIDTH is the width of the images on which the model was trained
|
|
// - 3 is the (R, G, B) values of the pixel colors represented as a float.
|
|
//
|
|
// And produces as output a vector with shape [ NUM_LABELS ].
|
|
// output[i] is the probability that the input image was recognized as
|
|
// having the i-th label.
|
|
//
|
|
// A separate file contains a list of string labels corresponding to the
|
|
// integer indices of the output.
|
|
//
|
|
// This example:
|
|
// - Loads the serialized representation of the pre-trained model into a Graph
|
|
// - Creates a Session to execute operations on the Graph
|
|
// - Converts an image file to a Tensor to provide as input to a Session run
|
|
// - Executes the Session and prints out the label with the highest probability
|
|
//
|
|
// To convert an image file to a Tensor suitable for input to the Inception model,
|
|
// this example:
|
|
// - Constructs another TensorFlow graph to normalize the image into a
|
|
// form suitable for the model (for example, resizing the image)
|
|
// - Creates and executes a Session to obtain a Tensor in this normalized form.
|
|
modeldir := flag.String("dir", "", "Directory containing the trained model files. The directory will be created and the model downloaded into it if necessary")
|
|
imagefile := flag.String("image", "", "Path of a JPEG-image to extract labels for")
|
|
flag.Parse()
|
|
if *modeldir == "" || *imagefile == "" {
|
|
flag.Usage()
|
|
return
|
|
}
|
|
// Load the serialized GraphDef from a file.
|
|
modelfile, labelsfile, err := modelFiles(*modeldir)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
model, err := ioutil.ReadFile(modelfile)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
// Construct an in-memory graph from the serialized form.
|
|
graph := tf.NewGraph()
|
|
if err := graph.Import(model, ""); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
// Create a session for inference over graph.
|
|
session, err := tf.NewSession(graph, nil)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
defer session.Close()
|
|
|
|
// Run inference on *imageFile.
|
|
// For multiple images, session.Run() can be called in a loop (and
|
|
// concurrently). Alternatively, images can be batched since the model
|
|
// accepts batches of image data as input.
|
|
tensor, err := makeTensorFromImage(*imagefile)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
output, err := session.Run(
|
|
map[tf.Output]*tf.Tensor{
|
|
graph.Operation("input").Output(0): tensor,
|
|
},
|
|
[]tf.Output{
|
|
graph.Operation("output").Output(0),
|
|
},
|
|
nil)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
// output[0].Value() is a vector containing probabilities of
|
|
// labels for each image in the "batch". The batch size was 1.
|
|
// Find the most probably label index.
|
|
probabilities := output[0].Value().([][]float32)[0]
|
|
printBestLabel(probabilities, labelsfile)
|
|
}
|
|
|
|
// printBestLabel prints, on standard output, the label whose probability is
// the highest together with that probability expressed as a percentage.
//
// probabilities[i] is the probability of the i-th label. When several entries
// share the maximum value, the one with the lowest index is reported.
// labelsFile is read as text containing one label per line, line i (0-based)
// naming label i.
//
// The process is terminated through the log package when probabilities is
// empty, when labelsFile cannot be opened, or when labelsFile has no line for
// the winning index. A read error while scanning labelsFile is only logged;
// the labels read up to that point are still used.
func printBestLabel(probabilities []float32, labelsFile string) {
	// Without at least one probability there is no best index to report.
	if len(probabilities) == 0 {
		log.Fatal("cannot pick a best label: no probabilities were provided")
	}
	bestIdx := 0
	for i, p := range probabilities {
		if p > probabilities[bestIdx] {
			bestIdx = i
		}
	}

	// Found the best match. Read the labels from labelsFile, which
	// contains one line per label.
	file, err := os.Open(labelsFile)
	if err != nil {
		log.Fatal(err)
	}
	scanner := bufio.NewScanner(file)
	var labels []string
	for scanner.Scan() {
		labels = append(labels, scanner.Text())
	}
	scanErr := scanner.Err()
	// Closed explicitly rather than deferred: the fatal path below exits the
	// process, which would skip deferred calls.
	file.Close()
	if scanErr != nil {
		log.Printf("ERROR: failed to read %s: %v", labelsFile, scanErr)
	}

	// A truncated or mismatched labels file must not cause an index panic.
	if bestIdx >= len(labels) {
		log.Fatalf("label index %d out of range: %s contains only %d labels", bestIdx, labelsFile, len(labels))
	}
	fmt.Printf("BEST MATCH: (%2.0f%% likely) %s\n", probabilities[bestIdx]*100.0, labels[bestIdx])
}
|
|
|
|
// Convert the image in filename to a Tensor suitable as input to the Inception model.
|
|
func makeTensorFromImage(filename string) (*tf.Tensor, error) {
|
|
bytes, err := ioutil.ReadFile(filename)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// DecodeJpeg uses a scalar String-valued tensor as input.
|
|
tensor, err := tf.NewTensor(string(bytes))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Construct a graph to normalize the image
|
|
graph, input, output, err := constructGraphToNormalizeImage()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Execute that graph to normalize this one image
|
|
session, err := tf.NewSession(graph, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer session.Close()
|
|
normalized, err := session.Run(
|
|
map[tf.Output]*tf.Tensor{input: tensor},
|
|
[]tf.Output{output},
|
|
nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return normalized[0], nil
|
|
}
|
|
|
|
// The inception model takes as input the image described by a Tensor in a very
|
|
// specific normalized format (a particular image size, shape of the input tensor,
|
|
// normalized pixel values etc.).
|
|
//
|
|
// This function constructs a graph of TensorFlow operations which takes as
|
|
// input a JPEG-encoded string and returns a tensor suitable as input to the
|
|
// inception model.
|
|
func constructGraphToNormalizeImage() (graph *tf.Graph, input, output tf.Output, err error) {
|
|
// Some constants specific to the pre-trained model at:
|
|
// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
|
|
//
|
|
// - The model was trained after with images scaled to 224x224 pixels.
|
|
// - The colors, represented as R, G, B in 1-byte each were converted to
|
|
// float using (value - Mean)/Scale.
|
|
const (
|
|
H, W = 224, 224
|
|
Mean = float32(117)
|
|
Scale = float32(1)
|
|
)
|
|
// - input is a String-Tensor, where the string the JPEG-encoded image.
|
|
// - The inception model takes a 4D tensor of shape
|
|
// [BatchSize, Height, Width, Colors=3], where each pixel is
|
|
// represented as a triplet of floats
|
|
// - Apply normalization on each pixel and use ExpandDims to make
|
|
// this single image be a "batch" of size 1 for ResizeBilinear.
|
|
s := op.NewScope()
|
|
input = op.Placeholder(s, tf.String)
|
|
output = op.Div(s,
|
|
op.Sub(s,
|
|
op.ResizeBilinear(s,
|
|
op.ExpandDims(s,
|
|
op.Cast(s,
|
|
op.DecodeJpeg(s, input, op.DecodeJpegChannels(3)), tf.Float),
|
|
op.Const(s.SubScope("make_batch"), int32(0))),
|
|
op.Const(s.SubScope("size"), []int32{H, W})),
|
|
op.Const(s.SubScope("mean"), Mean)),
|
|
op.Const(s.SubScope("scale"), Scale))
|
|
graph, err = s.Finalize()
|
|
return graph, input, output, err
|
|
}
|
|
|
|
func modelFiles(dir string) (modelfile, labelsfile string, err error) {
|
|
const URL = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"
|
|
var (
|
|
model = filepath.Join(dir, "tensorflow_inception_graph.pb")
|
|
labels = filepath.Join(dir, "imagenet_comp_graph_label_strings.txt")
|
|
zipfile = filepath.Join(dir, "inception5h.zip")
|
|
)
|
|
if filesExist(model, labels) == nil {
|
|
return model, labels, nil
|
|
}
|
|
log.Println("Did not find model in", dir, "downloading from", URL)
|
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
return "", "", err
|
|
}
|
|
if err := download(URL, zipfile); err != nil {
|
|
return "", "", fmt.Errorf("failed to download %v - %v", URL, err)
|
|
}
|
|
if err := unzip(dir, zipfile); err != nil {
|
|
return "", "", fmt.Errorf("failed to extract contents from model archive: %v", err)
|
|
}
|
|
os.Remove(zipfile)
|
|
return model, labels, filesExist(model, labels)
|
|
}
|
|
|
|
// filesExist checks, in argument order, that each of the named files can be
// stat'ed. It returns nil when all of them can (or when no names are given),
// and otherwise an error describing the first one that cannot.
func filesExist(files ...string) error {
	for i := 0; i < len(files); i++ {
		name := files[i]
		_, statErr := os.Stat(name)
		if statErr == nil {
			continue
		}
		return fmt.Errorf("unable to stat %s: %v", name, statErr)
	}
	return nil
}
|
|
|
|
// download fetches URL with an HTTP GET request and stores the response body
// in the file named filename, creating the file if needed and replacing any
// previous contents.
//
// An error is returned when the request fails, when the server answers with
// a status other than 200 OK (in which case filename is not touched), when
// the file cannot be opened, or when writing or closing the file fails.
func download(URL, filename string) error {
	resp, err := http.Get(URL)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Without this check the body of an error page (404, 500, ...) would be
	// saved as if it were the requested file.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected HTTP status %q", resp.Status)
	}

	// O_TRUNC discards leftovers of an earlier (possibly longer) file so the
	// result is exactly the downloaded bytes.
	file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	_, copyErr := io.Copy(file, resp.Body)
	// Close can report a write failure, so its error is not discarded; the
	// copy error, being the earlier one, takes precedence.
	closeErr := file.Close()
	if copyErr != nil {
		return copyErr
	}
	return closeErr
}
|
|
|
|
// unzip extracts every entry of the zip archive named zipfile into the
// directory dir.
//
// Directory entries are created as directories, and missing parent
// directories of file entries are created as needed. Existing files are
// overwritten. Extraction stops at the first failure and that error is
// returned; an entry whose name would place it outside of dir is such a
// failure.
func unzip(dir, zipfile string) error {
	r, err := zip.OpenReader(zipfile)
	if err != nil {
		return err
	}
	defer r.Close()
	for _, f := range r.File {
		if err := extractZipEntry(dir, f); err != nil {
			return err
		}
	}
	return nil
}

// extractZipEntry writes the single archive entry f below dir. It lives in
// its own function so that the entry's reader is closed as soon as the entry
// has been handled rather than when the whole archive is done.
func extractZipEntry(dir string, f *zip.File) error {
	target := filepath.Join(dir, f.Name)

	// Reject names such as "../x" that resolve outside of dir: the path of
	// target relative to dir must not start by climbing to a parent.
	parentPrefix := ".." + string(filepath.Separator)
	rel, err := filepath.Rel(dir, target)
	if err != nil || rel == ".." || (len(rel) >= len(parentPrefix) && rel[:len(parentPrefix)] == parentPrefix) {
		return fmt.Errorf("zip entry %q escapes destination directory %q", f.Name, dir)
	}

	log.Println("Extracting", f.Name)
	if f.FileInfo().IsDir() {
		return os.MkdirAll(target, 0755)
	}
	// Archives are not required to list a directory before the files in it.
	if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
		return err
	}

	src, err := f.Open()
	if err != nil {
		return err
	}
	defer src.Close()

	// O_TRUNC ensures a previously existing, longer file leaves no trailing
	// bytes behind.
	dst, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	if _, err := io.Copy(dst, src); err != nil {
		dst.Close()
		return err
	}
	// Close can report a write failure, so its result is the final verdict.
	return dst.Close()
}
|