194 lines
4.0 KiB
Go
194 lines
4.0 KiB
Go
/*
|
|
Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package tensorflow
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"testing"
|
|
)
|
|
|
|
func TestOperationAttrs(t *testing.T) {
|
|
g := NewGraph()
|
|
|
|
i := 0
|
|
makeConst := func(v interface{}) Output {
|
|
op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v)
|
|
i++
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return op
|
|
}
|
|
|
|
makeTensor := func(v interface{}) *Tensor {
|
|
tensor, err := NewTensor(v)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return tensor
|
|
}
|
|
|
|
cases := []OpSpec{
|
|
{
|
|
Name: "type",
|
|
Type: "Placeholder",
|
|
Attrs: map[string]interface{}{
|
|
"dtype": Float,
|
|
},
|
|
},
|
|
{
|
|
Name: "list(float)",
|
|
Type: "Bucketize",
|
|
Input: []Input{
|
|
makeConst([]float32{1, 2, 3, 4}),
|
|
},
|
|
Attrs: map[string]interface{}{
|
|
"boundaries": []float32{0, 1, 2, 3, 4, 5},
|
|
},
|
|
},
|
|
{
|
|
Name: "list(float) empty",
|
|
Type: "Bucketize",
|
|
Input: []Input{
|
|
makeConst([]float32{}),
|
|
},
|
|
Attrs: map[string]interface{}{
|
|
"boundaries": []float32(nil),
|
|
},
|
|
},
|
|
/* TODO(ashankar): debug this issue and add it back later.
|
|
{
|
|
Name: "list(type),list(shape)",
|
|
Type: "InfeedEnqueueTuple",
|
|
Input: []Input{
|
|
OutputList([]Output{
|
|
makeConst(float32(1)),
|
|
makeConst([][]int32{{2}}),
|
|
}),
|
|
},
|
|
Attrs: map[string]interface{}{
|
|
"dtypes": []DataType{Float, Int32},
|
|
"shapes": []Shape{ScalarShape(), MakeShape(1, 1)},
|
|
},
|
|
},
|
|
{
|
|
Name: "list(type),list(shape) empty",
|
|
Type: "InfeedEnqueueTuple",
|
|
Input: []Input{
|
|
OutputList([]Output{
|
|
makeConst([][]int32{{2}}),
|
|
}),
|
|
},
|
|
Attrs: map[string]interface{}{
|
|
"dtypes": []DataType{Int32},
|
|
"shapes": []Shape(nil),
|
|
},
|
|
},
|
|
{
|
|
Name: "list(type) empty,string empty,int",
|
|
Type: "_XlaSendFromHost",
|
|
Input: []Input{
|
|
OutputList([]Output{}),
|
|
makeConst(""),
|
|
},
|
|
Attrs: map[string]interface{}{
|
|
"Tinputs": []DataType(nil),
|
|
"key": "",
|
|
"device_ordinal": int64(0),
|
|
},
|
|
},
|
|
*/
|
|
{
|
|
Name: "list(int),int",
|
|
Type: "StringToHashBucketStrong",
|
|
Input: []Input{
|
|
makeConst(""),
|
|
},
|
|
Attrs: map[string]interface{}{
|
|
"num_buckets": int64(2),
|
|
"key": []int64{1, 2},
|
|
},
|
|
},
|
|
{
|
|
Name: "list(int) empty,int",
|
|
Type: "StringToHashBucketStrong",
|
|
Input: []Input{
|
|
makeConst(""),
|
|
},
|
|
Attrs: map[string]interface{}{
|
|
"num_buckets": int64(2),
|
|
"key": ([]int64)(nil),
|
|
},
|
|
},
|
|
{
|
|
Name: "list(string),type",
|
|
Type: "TensorSummary",
|
|
Input: []Input{
|
|
makeConst(""),
|
|
},
|
|
Attrs: map[string]interface{}{
|
|
"T": String,
|
|
"labels": []string{"foo", "bar"},
|
|
},
|
|
},
|
|
{
|
|
Name: "list(string) empty,type",
|
|
Type: "TensorSummary",
|
|
Input: []Input{
|
|
makeConst(""),
|
|
},
|
|
Attrs: map[string]interface{}{
|
|
"T": String,
|
|
"labels": ([]string)(nil),
|
|
},
|
|
},
|
|
{
|
|
Name: "tensor",
|
|
Type: "Const",
|
|
Attrs: map[string]interface{}{
|
|
"dtype": String,
|
|
"value": makeTensor("foo"),
|
|
},
|
|
},
|
|
}
|
|
|
|
for i, spec := range cases {
|
|
op, err := g.AddOperation(spec)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for key, want := range spec.Attrs {
|
|
out, err := op.Attr(key)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !reflect.DeepEqual(out, want) {
|
|
t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want)
|
|
}
|
|
wantT, ok := want.(*Tensor)
|
|
if ok {
|
|
wantVal := wantT.Value()
|
|
outVal := out.(*Tensor).Value()
|
|
if !reflect.DeepEqual(outVal, wantVal) {
|
|
t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|