320 lines
7.9 KiB
Go
320 lines
7.9 KiB
Go
/*
|
|
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package tensorflow
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"testing"
|
|
)
|
|
|
|
func createTestGraph(t *testing.T, dt DataType) (*Graph, Output, Output) {
|
|
g := NewGraph()
|
|
inp, err := Placeholder(g, "p1", dt)
|
|
if err != nil {
|
|
t.Fatalf("Placeholder() for %v: %v", dt, err)
|
|
}
|
|
out, err := Neg(g, "neg1", inp)
|
|
if err != nil {
|
|
t.Fatalf("Neg() for %v: %v", dt, err)
|
|
}
|
|
return g, inp, out
|
|
}
|
|
|
|
func TestSessionRunNeg(t *testing.T) {
|
|
var tests = []struct {
|
|
input interface{}
|
|
expected interface{}
|
|
}{
|
|
{int64(1), int64(-1)},
|
|
{[]float64{-1, -2, 3}, []float64{1, 2, -3}},
|
|
{[][]float32{{1, -2}, {-3, 4}}, [][]float32{{-1, 2}, {3, -4}}},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(fmt.Sprint(test.input), func(t *testing.T) {
|
|
t1, err := NewTensor(test.input)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
graph, inp, out := createTestGraph(t, t1.DataType())
|
|
s, err := NewSession(graph, &SessionOptions{})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
output, err := s.Run(map[Output]*Tensor{inp: t1}, []Output{out}, []*Operation{out.Op})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(output) != 1 {
|
|
t.Fatalf("got %d outputs, want 1", len(output))
|
|
}
|
|
val := output[0].Value()
|
|
if !reflect.DeepEqual(test.expected, val) {
|
|
t.Errorf("got %v, want %v", val, test.expected)
|
|
}
|
|
if err := s.Close(); err != nil {
|
|
t.Error(err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSessionRunConcat(t *testing.T) {
|
|
// Runs the Concat operation on two matrices: m1 and m2, along the
|
|
// first dimension (dim1).
|
|
// This tests the use of both Output and OutputList as inputs to the
|
|
// Concat operation.
|
|
var (
|
|
g = NewGraph()
|
|
dim1, _ = Const(g, "dim1", int32(1))
|
|
m1, _ = Const(g, "m1", [][]int64{
|
|
{1, 2, 3},
|
|
{4, 5, 6},
|
|
})
|
|
m2, _ = Const(g, "m2", [][]int64{
|
|
{7, 8, 9},
|
|
{10, 11, 12},
|
|
})
|
|
want = [][]int64{
|
|
{1, 2, 3, 7, 8, 9},
|
|
{4, 5, 6, 10, 11, 12},
|
|
}
|
|
)
|
|
concat, err := g.AddOperation(OpSpec{
|
|
Type: "Concat",
|
|
Input: []Input{
|
|
dim1,
|
|
OutputList{m1, m2},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
s, err := NewSession(g, &SessionOptions{})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
output, err := s.Run(nil, []Output{concat.Output(0)}, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(output) != 1 {
|
|
t.Fatal(len(output))
|
|
}
|
|
if got := output[0].Value(); !reflect.DeepEqual(got, want) {
|
|
t.Fatalf("Got %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestSessionWithStringTensors(t *testing.T) {
|
|
// Construct the graph:
|
|
// AsString(StringToHashBucketFast("PleaseHashMe")) Will be much
|
|
// prettier if using the ops package, but in this package graphs are
|
|
// constructed from first principles.
|
|
var (
|
|
g = NewGraph()
|
|
feed, _ = Const(g, "input", "PleaseHashMe")
|
|
hash, _ = g.AddOperation(OpSpec{
|
|
Type: "StringToHashBucketFast",
|
|
Input: []Input{feed},
|
|
Attrs: map[string]interface{}{
|
|
"num_buckets": int64(1 << 32),
|
|
},
|
|
})
|
|
str, _ = g.AddOperation(OpSpec{
|
|
Type: "AsString",
|
|
Input: []Input{hash.Output(0)},
|
|
})
|
|
)
|
|
s, err := NewSession(g, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
output, err := s.Run(nil, []Output{str.Output(0)}, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(output) != 1 {
|
|
t.Fatal(len(output))
|
|
}
|
|
got, ok := output[0].Value().(string)
|
|
if !ok {
|
|
t.Fatalf("Got %T, wanted string", output[0].Value())
|
|
}
|
|
if want := "1027741475"; got != want {
|
|
t.Fatalf("Got %q, want %q", got, want)
|
|
}
|
|
}
|
|
|
|
func TestConcurrency(t *testing.T) {
|
|
tensor, err := NewTensor(int64(1))
|
|
if err != nil {
|
|
t.Fatalf("NewTensor(): %v", err)
|
|
}
|
|
|
|
graph, inp, out := createTestGraph(t, tensor.DataType())
|
|
s, err := NewSession(graph, &SessionOptions{})
|
|
if err != nil {
|
|
t.Fatalf("NewSession(): %v", err)
|
|
}
|
|
for i := 0; i < 100; i++ {
|
|
// Session may close before Run() starts, so we don't check the error.
|
|
go s.Run(map[Output]*Tensor{inp: tensor}, []Output{out}, []*Operation{out.Op})
|
|
}
|
|
if err = s.Close(); err != nil {
|
|
t.Errorf("Close() 1: %v", err)
|
|
}
|
|
if err = s.Close(); err != nil {
|
|
t.Errorf("Close() 2: %v", err)
|
|
}
|
|
}
|
|
|
|
func ExamplePartialRun() {
|
|
var (
|
|
// Create a graph: a + 2 + 3 + b.
|
|
//
|
|
// Skipping error handling for brevity of this example.
|
|
// The 'op' package can be used to make graph construction code
|
|
// with error handling more succinct.
|
|
g = NewGraph()
|
|
a, _ = Placeholder(g, "a", Int32)
|
|
b, _ = Placeholder(g, "b", Int32)
|
|
two, _ = Const(g, "Two", int32(2))
|
|
three, _ = Const(g, "Three", int32(3))
|
|
|
|
plus2, _ = Add(g, "plus2", a, two) // a + 2
|
|
plus3, _ = Add(g, "plus3", plus2, three) // (a + 2) + 3
|
|
plusB, _ = Add(g, "plusB", plus3, b) // ((a + 2) + 3) + b
|
|
|
|
)
|
|
sess, err := NewSession(g, nil)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
defer sess.Close()
|
|
|
|
// All the feeds, fetches and targets for subsequent PartialRun.Run
|
|
// calls must be provided at setup.
|
|
pr, err := sess.NewPartialRun(
|
|
[]Output{a, b},
|
|
[]Output{plus2, plusB},
|
|
[]*Operation{plus3.Op},
|
|
)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
// Feed 'a=1', fetch 'plus2', and compute (but do not fetch) 'plus3'.
|
|
// Imagine this to be the forward pass of unsupervised neural network
|
|
// training of a robot.
|
|
val, _ := NewTensor(int32(1))
|
|
fetches, err := pr.Run(
|
|
map[Output]*Tensor{a: val},
|
|
[]Output{plus2},
|
|
nil)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
v1 := fetches[0].Value().(int32)
|
|
|
|
// Now, feed 'b=4', fetch 'plusB=a+2+3+b'
|
|
// Imagine this to be the result of actuating the robot to determine
|
|
// the error produced by the current state of the neural network.
|
|
val, _ = NewTensor(int32(4))
|
|
fetches, err = pr.Run(
|
|
map[Output]*Tensor{b: val},
|
|
[]Output{plusB},
|
|
nil)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
v2 := fetches[0].Value().(int32)
|
|
|
|
fmt.Println(v1, v2)
|
|
// Output: 3 10
|
|
}
|
|
|
|
func TestSessionConfig(t *testing.T) {
|
|
// Exercise SessionOptions.
|
|
// Arguably, a better API would be for SessionOptions.Config to be the
|
|
// type generated by the protocol buffer compiler. But for now, the
|
|
// tensorflow package continues to be independent of protocol buffers
|
|
// and this test exercises the option since the implementation has a
|
|
// nuanced conversion to C types.
|
|
//
|
|
// Till then, the []byte form of Config here was generated using a toy
|
|
// tensorflow Python program:
|
|
/*
|
|
import tensorflow
|
|
c = tensorflow.ConfigProto()
|
|
c.intra_op_parallelism_threads = 1
|
|
print c.SerializeToString()
|
|
*/
|
|
graph := NewGraph()
|
|
c, err := Const(graph, "Const", int32(14))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
opts := SessionOptions{Config: []byte("(\x01")}
|
|
s, err := NewSession(graph, &opts)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
output, err := s.Run(nil, []Output{c}, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if output[0].Value().(int32) != 14 {
|
|
t.Fatalf("Got %v, want -1", output[0].Value())
|
|
}
|
|
}
|
|
|
|
func TestListDevices(t *testing.T) {
|
|
s, err := NewSession(NewGraph(), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewSession(): %v", err)
|
|
}
|
|
|
|
devices, err := s.ListDevices()
|
|
if err != nil {
|
|
t.Fatalf("ListDevices(): %v", err)
|
|
}
|
|
|
|
if len(devices) == 0 {
|
|
t.Fatalf("no devices detected")
|
|
}
|
|
}
|
|
|
|
func TestDeviceString(t *testing.T) {
|
|
d := Device{Name: "foo", Type: "bar", MemoryLimitBytes: 12345}
|
|
got := d.String()
|
|
want := "(Device: name \"foo\", type bar, memory limit 12345 bytes)"
|
|
if got != want {
|
|
t.Errorf("Got \"%s\", want \"%s\"", got, want)
|
|
}
|
|
}
|
|
|
|
func TestDeviceStringNoMemoryLimit(t *testing.T) {
|
|
d := Device{Name: "foo", Type: "bar", MemoryLimitBytes: -1}
|
|
got := d.String()
|
|
want := "(Device: name \"foo\", type bar, no memory limit)"
|
|
if got != want {
|
|
t.Errorf("Got \"%s\", want \"%s\"", got, want)
|
|
}
|
|
}
|