Go: Bugfix: Make list-of-shape attributes in an operation work.

By respecting cgo rules on pointers.
Without the change to graph.go, the newly added test would fail with:

panic: runtime error: cgo argument has Go pointer to Go pointer

in the call to the C function TF_SetAttrShapeList.

Fixes #14891

PiperOrigin-RevId: 177336663
This commit is contained in:
Asim Shankar 2017-11-29 11:18:38 -08:00 committed by TensorFlower Gardener
parent c572bc4fd7
commit 78a4873cfa
2 changed files with 112 additions and 25 deletions

View File

@ -20,6 +20,24 @@ package tensorflow
//
// #include <stdlib.h>
// #include <string.h>
//
// void TF_SetAttrShapeList_Helper(TF_OperationDescription* desc,
// const char* attr_name,
// const int64_t* flat_dims,
// const int* num_dims,
// int num_shapes) {
// const int64_t** dims =
// (const int64_t**)malloc(sizeof(const int64_t*) * num_shapes);
// for (int i = 0; i < num_shapes; i++) {
// dims[i] = flat_dims;
// if (num_dims[i] > 0) {
// // flat_dims will be NULL iff num_shapes is 0 or all elements in num_dims are <= 0.
// flat_dims += num_dims[i];
// }
// }
// TF_SetAttrShapeList(desc, attr_name, dims, num_dims, num_shapes);
// free(dims);
// }
import "C"
import (
@ -289,41 +307,37 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
return fmt.Errorf("bad value for attribute %q: %v", name, err)
}
case Shape:
ndims, dims := cshape(value)
ndims := C.int(value.NumDimensions())
var dimsp *C.int64_t
if ndims > 0 {
dims := make([]C.int64_t, ndims)
for i, d := range value.dims {
dims[i] = C.int64_t(d)
}
dimsp = &dims[0]
}
C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims)
case []Shape:
ndims := make([]C.int, len(value))
dims := make([][]C.int64_t, len(value))
dimsp := make([]*C.int64_t, len(value))
for i, s := range value {
ndims[i], dims[i] = cshape(s)
if ndims[i] > 0 {
dimsp[i] = &dims[i][0]
}
}
if len(value) > 0 {
C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value)))
} else {
if len(value) == 0 {
C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0)
} else {
var flatDims []C.int64_t
ndims := make([]C.int, len(value))
for i, s := range value {
nd := s.NumDimensions()
ndims[i] = C.int(nd)
for _, d := range s.dims {
flatDims = append(flatDims, C.int64_t(d))
}
}
var flatDimsp *C.int64_t
if len(flatDims) > 0 {
flatDimsp = &flatDims[0]
}
C.TF_SetAttrShapeList_Helper(cdesc, cAttrName, flatDimsp, &ndims[0], C.int(len(value)))
}
default:
return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
}
return nil
}
func cshape(s Shape) (C.int, []C.int64_t) {
ndims := C.int(s.NumDimensions())
if ndims < 0 {
return -1, nil
}
dims := make([]C.int64_t, ndims)
for i, s := range s.dims {
dims[i] = C.int64_t(s)
}
return ndims, dims
}

View File

@ -58,3 +58,76 @@ func TestAddOperationFailure(t *testing.T) {
_ = resize.Shape()
t.Errorf("resize.Shape() should have paniced since the underlying Operation was not created")
}
func TestShapeAttribute(t *testing.T) {
s := NewScope()
x := Placeholder(s.SubScope("x"), tf.Int32, PlaceholderShape(tf.MakeShape(1)))
y := Placeholder(s.SubScope("y"), tf.Int32, PlaceholderShape(tf.Shape{}))
z := Add(s, x, y)
graph, err := s.Finalize()
if err != nil {
t.Fatal(err)
}
sess, err := tf.NewSession(graph, nil)
if err != nil {
t.Fatal(err)
}
value, err := tf.NewTensor([]int32{7})
if err != nil {
t.Fatal(err)
}
feeds := map[tf.Output]*tf.Tensor{
x: value,
y: value,
}
fetched, err := sess.Run(feeds, []tf.Output{z}, nil)
if err != nil {
t.Fatal(err)
}
if got, want := len(fetched), 1; got != want {
t.Fatalf("Fetched %d tensors, expected %d", got, want)
}
if got, want := fetched[0].Value().([]int32), []int32{14}; len(got) != len(want) || len(got) != 1 || got[0] != want[0] {
t.Fatalf("Got %v, want %v", got, want)
}
}
func TestDataset(t *testing.T) {
var (
s = NewScope()
// The use of a non-scalar here is inspired by
// https://github.com/tensorflow/tensorflow/issues/14891
c = Const(s, []int32{21718, 31415})
types = []tf.DataType{c.DataType()}
shapes = []tf.Shape{c.Shape()}
dataset = TensorDataset(s, []tf.Output{c}, shapes)
iterator = Iterator(s, "", "", types, shapes)
next = IteratorGetNext(s, iterator, types, shapes)
init = MakeIterator(s, dataset, iterator)
)
graph, err := s.Finalize()
if err != nil {
t.Fatal(err)
}
sess, err := tf.NewSession(graph, nil)
if err != nil {
t.Fatal(err)
}
if _, err := sess.Run(nil, nil, []*tf.Operation{init}); err != nil {
t.Fatal(err)
}
results, err := sess.Run(nil, next, nil)
if err != nil {
t.Fatal(err)
}
got := results[0].Value().([]int32)
if len(got) != 2 || got[0] != 21718 || got[1] != 31415 {
t.Errorf("Got %v, want {21718, 31415}", got)
}
if _, err := sess.Run(nil, next, nil); err == nil {
t.Errorf("Expected sess.Run() to fail since the iterator should have reached the end of the dataset")
}
}