STT-tensorflow/tensorflow/go/operation_test.go

270 lines
6.2 KiB
Go

/*
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package tensorflow
import (
"fmt"
"runtime"
"runtime/debug"
"testing"
)
// createGraphAndOp adds a Placeholder operation named "my_placeholder" to a
// newly created Graph and returns only that Operation. The Graph is not
// returned, so the caller ends up without any direct reference to it.
func createGraphAndOp() (*Operation, error) {
	// The tensor is needed only as a source of a DataType for the placeholder.
	tensor, err := NewTensor(int64(1))
	if err != nil {
		return nil, err
	}

	graph := NewGraph()
	dtype := tensor.DataType()
	placeholder, err := Placeholder(graph, "my_placeholder", dtype)
	if err != nil {
		return nil, err
	}
	return placeholder.Op, nil
}
// TestOperationLifetime verifies that the Graph is not garbage collected
// while the program can still reach an Operation created in it: the
// Operation's name and type must still be readable after a forced collection.
func TestOperationLifetime(t *testing.T) {
	const (
		wantName = "my_placeholder"
		wantType = "Placeholder"
	)

	operation, err := createGraphAndOp()
	if err != nil {
		t.Fatal(err)
	}

	// Only the Operation is held here; collect garbage before using it.
	forceGC()

	if name := operation.Name(); name != wantName {
		t.Errorf("Got '%s', want '%s'", name, wantName)
	}
	if kind := operation.Type(); kind != wantType {
		t.Errorf("Got '%s', want '%s'", kind, wantType)
	}
}
// TestOperationOutputListSize builds a ShapeN operation from two constants
// and checks that both the size of its "output" list and its total number of
// outputs equal 2.
func TestOperationOutputListSize(t *testing.T) {
	const want = 2

	g := NewGraph()
	scalar, err := Const(g, "c1", int64(1))
	if err != nil {
		t.Fatal(err)
	}
	matrix, err := Const(g, "c2", [][]int64{{1, 2}, {3, 4}})
	if err != nil {
		t.Fatal(err)
	}

	// The ShapeN op takes a list of tensors as input and a list as output.
	spec := OpSpec{
		Type:  "ShapeN",
		Input: []Input{OutputList{scalar, matrix}},
	}
	shapeN, err := g.AddOperation(spec)
	if err != nil {
		t.Fatal(err)
	}

	listSize, err := shapeN.OutputListSize("output")
	if err != nil {
		t.Fatal(err)
	}
	if listSize != want {
		t.Errorf("Got %d, want %d", listSize, want)
	}
	if numOutputs := shapeN.NumOutputs(); numOutputs != want {
		t.Errorf("Got %d, want %d", numOutputs, want)
	}
}
// TestOperationShapeAttribute checks that a Placeholder operation carrying a
// "shape" attribute built with MakeShape can be added to a graph without
// error.
func TestOperationShapeAttribute(t *testing.T) {
	graph := NewGraph()
	spec := OpSpec{
		Type: "Placeholder",
		Attrs: map[string]interface{}{
			"dtype": Float,
			"shape": MakeShape(-1, 3),
		},
	}
	if _, err := graph.AddOperation(spec); err != nil {
		t.Fatal(err)
	}
	// If and when the API to get attributes is added, check that here.
}
// TestOutputDataTypeAndShape creates constants of several ranks and element
// types and checks the DataType and Shape reported by each resulting Output.
// It finishes by checking that an Output obtained from the Placeholder helper
// reports an unknown number of dimensions (-1).
func TestOutputDataTypeAndShape(t *testing.T) {
	g := NewGraph()
	cases := []struct {
		value    interface{}
		wantDims []int64
		wantType DataType
	}{
		// Scalar
		{value: int64(0), wantDims: []int64{}, wantType: Int64},
		// Vector
		{value: []int32{1, 2, 3}, wantDims: []int64{3}, wantType: Int32},
		// Matrix
		{
			value: [][]float64{
				{1, 2, 3},
				{4, 5, 6},
			},
			wantDims: []int64{2, 3},
			wantType: Double,
		},
		// Matrix of Uint64
		{
			value: [][]uint64{
				{1, 2, 3},
				{4, 5, 6},
			},
			wantDims: []int64{2, 3},
			wantType: Uint64,
		},
	}

	for n, tc := range cases {
		subtest := fmt.Sprintf("#%d Value %T", n, tc.value)
		t.Run(subtest, func(t *testing.T) {
			// Every constant gets a distinct node name within the shared graph.
			out, err := Const(g, fmt.Sprintf("const%d", n), tc.value)
			if err != nil {
				t.Fatal(err)
			}
			if got := out.DataType(); got != tc.wantType {
				t.Errorf("Got DataType %v, want %v", got, tc.wantType)
			}

			shape := out.Shape()
			rank := len(tc.wantDims)
			// Stop on a rank mismatch; the per-dimension checks below assume
			// the ranks agree.
			if got := shape.NumDimensions(); got != rank {
				t.Fatalf("Got a shape with %d dimensions, want %d", got, rank)
			}
			for dim, want := range tc.wantDims {
				if got := shape.Size(dim); got != want {
					t.Errorf("Got %d, want %d for dimension #%d/%d", got, want, dim, rank)
				}
			}
		})
	}

	// Unknown number of dimensions
	scratch, err := NewTensor(float64(0))
	if err != nil {
		t.Fatal(err)
	}
	ph, err := Placeholder(g, "placeholder", scratch.DataType())
	if err != nil {
		t.Fatal(err)
	}
	shape := ph.Shape()
	if shape.NumDimensions() != -1 {
		t.Errorf("Got shape %v, wanted an unknown number of dimensions", shape)
	}
}
// TestOperationInputs checks that an Add operation built from two
// placeholders reports exactly two inputs.
func TestOperationInputs(t *testing.T) {
	graph := NewGraph()

	lhs, err := Placeholder(graph, "x", Float)
	if err != nil {
		t.Fatal(err)
	}
	rhs, err := Placeholder(graph, "y", Float)
	if err != nil {
		t.Fatal(err)
	}
	sum, err := Add(graph, "add", lhs, rhs)
	if err != nil {
		t.Fatal(err)
	}

	if n := sum.Op.NumInputs(); n != 2 {
		t.Fatalf("Got %d inputs, wanted 2", n)
	}
}
// TestOperationConsumers feeds a single placeholder into two Neg operations
// ("a", then "b") and checks that the placeholder reports exactly those two
// consumers in that order, that each consumer names the placeholder as its
// producer, and that the output of "b", which nothing reads, has no
// consumers.
func TestOperationConsumers(t *testing.T) {
	graph := NewGraph()
	src, err := Placeholder(graph, "x", Float)
	if err != nil {
		t.Fatal(err)
	}
	negA, err := Neg(graph, "a", src)
	if err != nil {
		t.Fatal(err)
	}
	negB, err := Neg(graph, "b", src)
	if err != nil {
		t.Fatal(err)
	}

	wantOps := []*Operation{negA.Op, negB.Op}
	found := src.Consumers()
	// Bail out on a count mismatch so that indexing wantOps below is safe.
	if n := len(found); n != 2 {
		t.Fatalf("Got %d consumers, wanted 2", n)
	}

	for i, c := range found {
		if got, want := c.Op.Name(), wantOps[i].Name(); got != want {
			t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
		}
		if got, want := c.Producer().Op.Name(), src.Op.Name(); got != want {
			t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
		}
	}

	if unread := negB.Consumers(); len(unread) != 0 {
		t.Fatalf("expected %+v to have no consumers", negB)
	}
}
// TestOperationDevice checks that the device string supplied through
// OpSpec.Device is the one reported back by Operation.Device.
func TestOperationDevice(t *testing.T) {
	const device = "/device:GPU:0"

	g := NewGraph()
	value, err := NewTensor(float32(1.0))
	if err != nil {
		t.Fatal(err)
	}

	spec := OpSpec{
		Type: "Const",
		Name: "Const",
		Attrs: map[string]interface{}{
			"dtype": value.DataType(),
			"value": value,
		},
		Device: device,
	}
	constOp, err := g.AddOperation(spec)
	if err != nil {
		t.Fatal(err)
	}

	if got := constOp.Device(); got != device {
		t.Errorf("Got %q, want %q", got, device)
	}
}
// forceGC makes a best-effort attempt to trigger a garbage collection: it
// first allocates just enough bytes to push the heap past the collector's
// next target, then calls runtime.GC and debug.FreeOSMemory.
func forceGC() {
	// Largest padding allocation forceGC is willing to make. A bigger gap
	// means the reported GC target is not a realistic heap size (for example
	// when the collector is disabled), so padding is pointless.
	const maxPadding = 1 << 30

	var mem runtime.MemStats
	runtime.ReadMemStats(&mem)
	// It was empirically observed that without this extra allocation
	// TestOperationLifetime would fail only 50% of the time if
	// Operation did not hold on to a reference to Graph. With this
	// additional allocation, and with the bug where Operation does
	// not hold onto a Graph, the test failed 90+% of the time.
	//
	// The author is aware that this technique is potentially fragile
	// and fishy. Suggestions for alternatives are welcome.
	//
	// NextGC is only a target. HeapAlloc may already exceed it, in which
	// case the unsigned subtraction would wrap around, and NextGC is huge
	// when the collector is disabled (GOGC=off). Either way make would
	// panic, so the padding is skipped in those cases; runtime.GC below
	// still forces a collection.
	if mem.NextGC >= mem.HeapAlloc {
		if gap := mem.NextGC - mem.HeapAlloc; gap < maxPadding {
			_ = make([]byte, gap+1)
		}
	}
	runtime.GC()
	debug.FreeOSMemory()
}