270 lines
6.2 KiB
Go
270 lines
6.2 KiB
Go
/*
Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package tensorflow
|
|
|
|
import (
|
|
"fmt"
|
|
"runtime"
|
|
"runtime/debug"
|
|
"testing"
|
|
)
|
|
|
|
// createGraphAndOp creates an Operation but loses the reference to the Graph.
|
|
func createGraphAndOp() (*Operation, error) {
|
|
t, err := NewTensor(int64(1))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
g := NewGraph()
|
|
output, err := Placeholder(g, "my_placeholder", t.DataType())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return output.Op, nil
|
|
}
|
|
|
|
func TestOperationLifetime(t *testing.T) {
|
|
// Ensure that the Graph is not garbage collected while the program
|
|
// still has access to the Operation.
|
|
op, err := createGraphAndOp()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
forceGC()
|
|
if got, want := op.Name(), "my_placeholder"; got != want {
|
|
t.Errorf("Got '%s', want '%s'", got, want)
|
|
}
|
|
if got, want := op.Type(), "Placeholder"; got != want {
|
|
t.Errorf("Got '%s', want '%s'", got, want)
|
|
}
|
|
}
|
|
|
|
func TestOperationOutputListSize(t *testing.T) {
|
|
graph := NewGraph()
|
|
c1, err := Const(graph, "c1", int64(1))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
c2, err := Const(graph, "c2", [][]int64{{1, 2}, {3, 4}})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// The ShapeN op takes a list of tensors as input and a list as output.
|
|
op, err := graph.AddOperation(OpSpec{
|
|
Type: "ShapeN",
|
|
Input: []Input{OutputList{c1, c2}},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
n, err := op.OutputListSize("output")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got, want := n, 2; got != want {
|
|
t.Errorf("Got %d, want %d", got, want)
|
|
}
|
|
if got, want := op.NumOutputs(), 2; got != want {
|
|
t.Errorf("Got %d, want %d", got, want)
|
|
}
|
|
}
|
|
|
|
func TestOperationShapeAttribute(t *testing.T) {
|
|
g := NewGraph()
|
|
_, err := g.AddOperation(OpSpec{
|
|
Type: "Placeholder",
|
|
Attrs: map[string]interface{}{
|
|
"dtype": Float,
|
|
"shape": MakeShape(-1, 3),
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// If and when the API to get attributes is added, check that here.
|
|
}
|
|
|
|
func TestOutputDataTypeAndShape(t *testing.T) {
|
|
graph := NewGraph()
|
|
testdata := []struct {
|
|
Value interface{}
|
|
Shape []int64
|
|
dtype DataType
|
|
}{
|
|
{ // Scalar
|
|
int64(0),
|
|
[]int64{},
|
|
Int64,
|
|
},
|
|
{ // Vector
|
|
[]int32{1, 2, 3},
|
|
[]int64{3},
|
|
Int32,
|
|
},
|
|
{ // Matrix
|
|
[][]float64{
|
|
{1, 2, 3},
|
|
{4, 5, 6},
|
|
},
|
|
[]int64{2, 3},
|
|
Double,
|
|
},
|
|
{ // Matrix of Uint64
|
|
[][]uint64{
|
|
{1, 2, 3},
|
|
{4, 5, 6},
|
|
},
|
|
[]int64{2, 3},
|
|
Uint64,
|
|
},
|
|
}
|
|
for idx, test := range testdata {
|
|
t.Run(fmt.Sprintf("#%d Value %T", idx, test.Value), func(t *testing.T) {
|
|
c, err := Const(graph, fmt.Sprintf("const%d", idx), test.Value)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got, want := c.DataType(), test.dtype; got != want {
|
|
t.Errorf("Got DataType %v, want %v", got, want)
|
|
}
|
|
shape := c.Shape()
|
|
if got, want := shape.NumDimensions(), len(test.Shape); got != want {
|
|
t.Fatalf("Got a shape with %d dimensions, want %d", got, want)
|
|
}
|
|
for i := 0; i < len(test.Shape); i++ {
|
|
if got, want := shape.Size(i), test.Shape[i]; got != want {
|
|
t.Errorf("Got %d, want %d for dimension #%d/%d", got, want, i, len(test.Shape))
|
|
}
|
|
}
|
|
})
|
|
}
|
|
// Unknown number of dimensions
|
|
dummyTensor, err := NewTensor(float64(0))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
placeholder, err := Placeholder(graph, "placeholder", dummyTensor.DataType())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if shape := placeholder.Shape(); shape.NumDimensions() != -1 {
|
|
t.Errorf("Got shape %v, wanted an unknown number of dimensions", shape)
|
|
}
|
|
}
|
|
|
|
func TestOperationInputs(t *testing.T) {
|
|
g := NewGraph()
|
|
x, err := Placeholder(g, "x", Float)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
y, err := Placeholder(g, "y", Float)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
add, err := Add(g, "add", x, y)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
addOp := add.Op
|
|
|
|
if out := addOp.NumInputs(); out != 2 {
|
|
t.Fatalf("Got %d inputs, wanted 2", out)
|
|
}
|
|
}
|
|
|
|
func TestOperationConsumers(t *testing.T) {
|
|
g := NewGraph()
|
|
x, err := Placeholder(g, "x", Float)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
a, err := Neg(g, "a", x)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
b, err := Neg(g, "b", x)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
consumers := []*Operation{a.Op, b.Op}
|
|
|
|
xConsumers := x.Consumers()
|
|
if out := len(xConsumers); out != 2 {
|
|
t.Fatalf("Got %d consumers, wanted 2", out)
|
|
}
|
|
|
|
for i, consumer := range xConsumers {
|
|
got := consumer.Op.Name()
|
|
want := consumers[i].Name()
|
|
if got != want {
|
|
t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
|
|
}
|
|
|
|
got = consumer.Producer().Op.Name()
|
|
want = x.Op.Name()
|
|
if got != want {
|
|
t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
|
|
}
|
|
}
|
|
|
|
if len(b.Consumers()) != 0 {
|
|
t.Fatalf("expected %+v to have no consumers", b)
|
|
}
|
|
}
|
|
|
|
func TestOperationDevice(t *testing.T) {
|
|
graph := NewGraph()
|
|
v, err := NewTensor(float32(1.0))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
op, err := graph.AddOperation(OpSpec{
|
|
Type: "Const",
|
|
Name: "Const",
|
|
Attrs: map[string]interface{}{
|
|
"dtype": v.DataType(),
|
|
"value": v,
|
|
},
|
|
Device: "/device:GPU:0",
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got, want := op.Device(), "/device:GPU:0"; got != want {
|
|
t.Errorf("Got %q, want %q", got, want)
|
|
}
|
|
}
|
|
|
|
// forceGC makes a best-effort attempt to get unreachable objects collected
// right now: it allocates enough bytes to push the heap past the runtime's
// next-GC target, then explicitly runs a collection and asks the runtime to
// return freed memory to the operating system.
func forceGC() {
	var mem runtime.MemStats
	runtime.ReadMemStats(&mem)
	// It was empirically observed that without this extra allocation
	// TestOperationLifetime would fail only 50% of the time if
	// Operation did not hold on to a reference to Graph. With this
	// additional allocation, and with the bug where Operation does
	// not hold onto a Graph, the test failed 90+% of the time.
	//
	// The author is aware that this technique is potentially fragile
	// and fishy. Suggestions for alternatives are welcome.
	//
	// NextGC and HeapAlloc are both uint64. If the heap has already grown
	// past the target, the subtraction would wrap around to an enormous
	// value and the allocation would panic or exhaust memory, so the extra
	// allocation is skipped in that case (a collection is already due).
	if mem.NextGC >= mem.HeapAlloc {
		bytesTillGC := mem.NextGC - mem.HeapAlloc + 1
		_ = make([]byte, bytesTillGC)
	}
	runtime.GC()
	debug.FreeOSMemory()
}
|