STT-tensorflow/tensorflow/go/session_test.go

320 lines
7.9 KiB
Go

/*
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package tensorflow
import (
"fmt"
"reflect"
"testing"
)
// createTestGraph builds a two-node graph for element type dt: a
// placeholder named "p1" feeding a Neg op named "neg1". It returns the
// graph, the placeholder output (the value to feed) and the Neg output
// (the value to fetch). Any construction failure aborts the calling test.
func createTestGraph(t *testing.T, dt DataType) (*Graph, Output, Output) {
	graph := NewGraph()

	placeholder, err := Placeholder(graph, "p1", dt)
	if err != nil {
		t.Fatalf("Placeholder() for %v: %v", dt, err)
	}

	negated, err := Neg(graph, "neg1", placeholder)
	if err != nil {
		t.Fatalf("Neg() for %v: %v", dt, err)
	}

	return graph, placeholder, negated
}
// TestSessionRunNeg feeds tensors of several ranks and element types
// through a single Neg op and checks the fetched values. Each case runs as
// its own subtest with a fresh graph and session. The session is closed in
// a defer so it is released even when a t.Fatal aborts the subtest early.
func TestSessionRunNeg(t *testing.T) {
	var tests = []struct {
		input    interface{}
		expected interface{}
	}{
		{int64(1), int64(-1)},
		{[]float64{-1, -2, 3}, []float64{1, 2, -3}},
		{[][]float32{{1, -2}, {-3, 4}}, [][]float32{{-1, 2}, {3, -4}}},
	}
	for _, test := range tests {
		t.Run(fmt.Sprint(test.input), func(t *testing.T) {
			t1, err := NewTensor(test.input)
			if err != nil {
				t.Fatal(err)
			}
			graph, inp, out := createTestGraph(t, t1.DataType())
			s, err := NewSession(graph, &SessionOptions{})
			if err != nil {
				t.Fatal(err)
			}
			// Deferred so the session is closed on every exit path,
			// including the t.Fatal calls below (FailNow runs defers).
			defer func() {
				if err := s.Close(); err != nil {
					t.Error(err)
				}
			}()
			output, err := s.Run(map[Output]*Tensor{inp: t1}, []Output{out}, []*Operation{out.Op})
			if err != nil {
				t.Fatal(err)
			}
			if len(output) != 1 {
				t.Fatalf("got %d outputs, want 1", len(output))
			}
			val := output[0].Value()
			if !reflect.DeepEqual(test.expected, val) {
				t.Errorf("got %v, want %v", val, test.expected)
			}
		})
	}
}
// TestSessionRunConcat runs the Concat operation on two 2x3 int64 matrices,
// m1 and m2, along dimension 1 (the columns, counting from 0), which yields
// a 2x6 matrix. It exercises both an Output (the axis) and an OutputList
// (the values being joined) as inputs to a single operation.
func TestSessionRunConcat(t *testing.T) {
	g := NewGraph()
	// mustConst adds a Const node to g and aborts the test on failure.
	// A construction error is then reported here rather than surfacing
	// later as a confusing AddOperation or Run failure.
	mustConst := func(name string, value interface{}) Output {
		out, err := Const(g, name, value)
		if err != nil {
			t.Fatalf("Const(%q): %v", name, err)
		}
		return out
	}
	dim1 := mustConst("dim1", int32(1))
	m1 := mustConst("m1", [][]int64{
		{1, 2, 3},
		{4, 5, 6},
	})
	m2 := mustConst("m2", [][]int64{
		{7, 8, 9},
		{10, 11, 12},
	})
	want := [][]int64{
		{1, 2, 3, 7, 8, 9},
		{4, 5, 6, 10, 11, 12},
	}
	concat, err := g.AddOperation(OpSpec{
		Type: "Concat",
		Input: []Input{
			dim1,
			OutputList{m1, m2},
		},
	})
	if err != nil {
		t.Fatal(err)
	}
	s, err := NewSession(g, &SessionOptions{})
	if err != nil {
		t.Fatal(err)
	}
	// Release the session on every exit path, including t.Fatal below.
	defer func() {
		if err := s.Close(); err != nil {
			t.Error(err)
		}
	}()
	output, err := s.Run(nil, []Output{concat.Output(0)}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(output) != 1 {
		t.Fatalf("got %d outputs, want 1", len(output))
	}
	if got := output[0].Value(); !reflect.DeepEqual(got, want) {
		t.Fatalf("Got %v, want %v", got, want)
	}
}
// TestSessionWithStringTensors checks that string tensors round-trip
// through a session. It hashes a constant string into one of 2^32 buckets
// and converts the bucket index back to a string, comparing against a
// golden value.
func TestSessionWithStringTensors(t *testing.T) {
	// Construct the graph:
	//   AsString(StringToHashBucketFast("PleaseHashMe"))
	// This would be much prettier using the op package, but in this
	// package graphs are constructed from first principles.
	g := NewGraph()
	feed, err := Const(g, "input", "PleaseHashMe")
	if err != nil {
		t.Fatal(err)
	}
	hash, err := g.AddOperation(OpSpec{
		Type:  "StringToHashBucketFast",
		Input: []Input{feed},
		Attrs: map[string]interface{}{
			"num_buckets": int64(1 << 32),
		},
	})
	if err != nil {
		// Without this check a nil hash would panic on hash.Output(0).
		t.Fatal(err)
	}
	str, err := g.AddOperation(OpSpec{
		Type:  "AsString",
		Input: []Input{hash.Output(0)},
	})
	if err != nil {
		t.Fatal(err)
	}
	s, err := NewSession(g, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Release the session on every exit path, including t.Fatal below.
	defer func() {
		if err := s.Close(); err != nil {
			t.Error(err)
		}
	}()
	output, err := s.Run(nil, []Output{str.Output(0)}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(output) != 1 {
		t.Fatalf("got %d outputs, want 1", len(output))
	}
	got, ok := output[0].Value().(string)
	if !ok {
		t.Fatalf("Got %T, wanted string", output[0].Value())
	}
	if want := "1027741475"; got != want {
		t.Fatalf("Got %q, want %q", got, want)
	}
}
// TestConcurrency starts 100 Run calls in background goroutines, then
// closes the session twice while those calls may still be in flight. Both
// Close calls must succeed. Each Run's result is deliberately ignored
// because the session may already be closed by the time a goroutine calls
// Run.
func TestConcurrency(t *testing.T) {
	scalar, err := NewTensor(int64(1))
	if err != nil {
		t.Fatalf("NewTensor(): %v", err)
	}
	g, feed, fetch := createTestGraph(t, scalar.DataType())
	sess, err := NewSession(g, &SessionOptions{})
	if err != nil {
		t.Fatalf("NewSession(): %v", err)
	}

	for remaining := 100; remaining > 0; remaining-- {
		// Close may win the race against this Run; its error is
		// intentionally dropped.
		go sess.Run(map[Output]*Tensor{feed: scalar}, []Output{fetch}, []*Operation{fetch.Op})
	}

	// Closing twice must be harmless.
	if closeErr := sess.Close(); closeErr != nil {
		t.Errorf("Close() 1: %v", closeErr)
	}
	if closeErr := sess.Close(); closeErr != nil {
		t.Errorf("Close() 2: %v", closeErr)
	}
}
func ExamplePartialRun() {
	// Create a graph computing a + 2 + 3 + b.
	//
	// Errors from graph construction are ignored to keep this example
	// short. The 'op' package can make graph construction code with
	// error handling more succinct.
	g := NewGraph()
	a, _ := Placeholder(g, "a", Int32)
	b, _ := Placeholder(g, "b", Int32)
	two, _ := Const(g, "Two", int32(2))
	three, _ := Const(g, "Three", int32(3))
	plus2, _ := Add(g, "plus2", a, two)       // a + 2
	plus3, _ := Add(g, "plus3", plus2, three) // (a + 2) + 3
	plusB, _ := Add(g, "plusB", plus3, b)     // ((a + 2) + 3) + b

	sess, err := NewSession(g, nil)
	if err != nil {
		panic(err)
	}
	defer sess.Close()

	// A partial run must be told up front about every feed, fetch and
	// target that later PartialRun.Run calls will use.
	pr, err := sess.NewPartialRun(
		[]Output{a, b},
		[]Output{plus2, plusB},
		[]*Operation{plus3.Op},
	)
	if err != nil {
		panic(err)
	}

	// Step 1: feed a=1 and fetch plus2. No targets are passed in this
	// call. Think of this step as the forward pass of unsupervised
	// neural network training of a robot.
	aVal, _ := NewTensor(int32(1))
	first, err := pr.Run(map[Output]*Tensor{a: aVal}, []Output{plus2}, nil)
	if err != nil {
		panic(err)
	}
	v1 := first[0].Value().(int32)

	// Step 2: feed b=4 and fetch plusB = a+2+3+b. Think of this step as
	// actuating the robot to measure the error produced by the current
	// state of the neural network.
	bVal, _ := NewTensor(int32(4))
	second, err := pr.Run(map[Output]*Tensor{b: bVal}, []Output{plusB}, nil)
	if err != nil {
		panic(err)
	}
	v2 := second[0].Value().(int32)

	fmt.Println(v1, v2)
	// Output: 3 10
}
// TestSessionConfig exercises SessionOptions.Config. It creates a session
// from a small hand-serialized config and runs a Const through it.
func TestSessionConfig(t *testing.T) {
	// Arguably, a better API would be for SessionOptions.Config to be the
	// type generated by the protocol buffer compiler. But for now, the
	// tensorflow package continues to be independent of protocol buffers.
	// This test exercises the option because the implementation has a
	// nuanced conversion to C types.
	//
	// Until then, Config is a hand-written serialized ConfigProto. The
	// bytes "(\x01" hold a single varint field: tag 0x28 ('(') means
	// field number 5 with wire type 0, and the value is 1. In ConfigProto
	// field 5 is inter_op_parallelism_threads, so this matches the output
	// of the toy Python program:
	/*
		import tensorflow
		c = tensorflow.ConfigProto()
		c.inter_op_parallelism_threads = 1
		print(c.SerializeToString())
	*/
	graph := NewGraph()
	c, err := Const(graph, "Const", int32(14))
	if err != nil {
		t.Fatal(err)
	}
	opts := SessionOptions{Config: []byte("(\x01")}
	s, err := NewSession(graph, &opts)
	if err != nil {
		t.Fatal(err)
	}
	// Release the session on every exit path, including t.Fatal below.
	defer func() {
		if err := s.Close(); err != nil {
			t.Error(err)
		}
	}()
	output, err := s.Run(nil, []Output{c}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if len(output) != 1 {
		t.Fatalf("got %d outputs, want 1", len(output))
	}
	// Use a checked assertion so a wrong dtype fails the test instead of
	// panicking.
	if got, ok := output[0].Value().(int32); !ok || got != 14 {
		t.Fatalf("Got %v, want 14", output[0].Value())
	}
}
// TestListDevices checks that a session created over an empty graph
// reports at least one device.
func TestListDevices(t *testing.T) {
	s, err := NewSession(NewGraph(), nil)
	if err != nil {
		t.Fatalf("NewSession(): %v", err)
	}
	// Release the session on every exit path, including t.Fatalf below.
	defer func() {
		if err := s.Close(); err != nil {
			t.Errorf("Close(): %v", err)
		}
	}()
	devices, err := s.ListDevices()
	if err != nil {
		t.Fatalf("ListDevices(): %v", err)
	}
	if len(devices) == 0 {
		t.Fatalf("no devices detected")
	}
}
// TestDeviceString checks the String rendering of a Device that has a
// finite memory limit.
func TestDeviceString(t *testing.T) {
	const want = `(Device: name "foo", type bar, memory limit 12345 bytes)`
	dev := Device{Name: "foo", Type: "bar", MemoryLimitBytes: 12345}
	if got := dev.String(); got != want {
		t.Errorf("Got \"%s\", want \"%s\"", got, want)
	}
}
// TestDeviceStringNoMemoryLimit checks that a Device whose
// MemoryLimitBytes is -1 is rendered with "no memory limit".
func TestDeviceStringNoMemoryLimit(t *testing.T) {
	const want = `(Device: name "foo", type bar, no memory limit)`
	dev := Device{Name: "foo", Type: "bar", MemoryLimitBytes: -1}
	if got := dev.String(); got != want {
		t.Errorf("Got \"%s\", want \"%s\"", got, want)
	}
}