Go: Bugfix: Make list-of-shape attributes in an operation work.
By respecting cgo rules on pointers. Without the change to graph.go, the newly added test would fail with: panic: runtime error: cgo argument has Go pointer to Go pointer in the call to the C function TF_SetAttrShapeList. Fixes #14891 PiperOrigin-RevId: 177336663
This commit is contained in:
parent
c572bc4fd7
commit
78a4873cfa
@ -20,6 +20,24 @@ package tensorflow
|
||||
//
|
||||
// #include <stdlib.h>
|
||||
// #include <string.h>
|
||||
//
|
||||
// void TF_SetAttrShapeList_Helper(TF_OperationDescription* desc,
|
||||
// const char* attr_name,
|
||||
// const int64_t* flat_dims,
|
||||
// const int* num_dims,
|
||||
// int num_shapes) {
|
||||
// const int64_t** dims =
|
||||
// (const int64_t**)malloc(sizeof(const int64_t*) * num_shapes);
|
||||
// for (int i = 0; i < num_shapes; i++) {
|
||||
// dims[i] = flat_dims;
|
||||
// if (num_dims[i] > 0) {
|
||||
// // flat_dims will be NULL iff num_shapes is 0 or all elements in num_dims are <= 0.
|
||||
// flat_dims += num_dims[i];
|
||||
// }
|
||||
// }
|
||||
// TF_SetAttrShapeList(desc, attr_name, dims, num_dims, num_shapes);
|
||||
// free(dims);
|
||||
// }
|
||||
import "C"
|
||||
|
||||
import (
|
||||
@ -289,41 +307,37 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
|
||||
return fmt.Errorf("bad value for attribute %q: %v", name, err)
|
||||
}
|
||||
case Shape:
|
||||
ndims, dims := cshape(value)
|
||||
ndims := C.int(value.NumDimensions())
|
||||
var dimsp *C.int64_t
|
||||
if ndims > 0 {
|
||||
dims := make([]C.int64_t, ndims)
|
||||
for i, d := range value.dims {
|
||||
dims[i] = C.int64_t(d)
|
||||
}
|
||||
dimsp = &dims[0]
|
||||
}
|
||||
C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims)
|
||||
case []Shape:
|
||||
ndims := make([]C.int, len(value))
|
||||
dims := make([][]C.int64_t, len(value))
|
||||
dimsp := make([]*C.int64_t, len(value))
|
||||
for i, s := range value {
|
||||
ndims[i], dims[i] = cshape(s)
|
||||
if ndims[i] > 0 {
|
||||
dimsp[i] = &dims[i][0]
|
||||
}
|
||||
}
|
||||
if len(value) > 0 {
|
||||
C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value)))
|
||||
} else {
|
||||
if len(value) == 0 {
|
||||
C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0)
|
||||
} else {
|
||||
var flatDims []C.int64_t
|
||||
ndims := make([]C.int, len(value))
|
||||
for i, s := range value {
|
||||
nd := s.NumDimensions()
|
||||
ndims[i] = C.int(nd)
|
||||
for _, d := range s.dims {
|
||||
flatDims = append(flatDims, C.int64_t(d))
|
||||
}
|
||||
}
|
||||
var flatDimsp *C.int64_t
|
||||
if len(flatDims) > 0 {
|
||||
flatDimsp = &flatDims[0]
|
||||
}
|
||||
C.TF_SetAttrShapeList_Helper(cdesc, cAttrName, flatDimsp, &ndims[0], C.int(len(value)))
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cshape(s Shape) (C.int, []C.int64_t) {
|
||||
ndims := C.int(s.NumDimensions())
|
||||
if ndims < 0 {
|
||||
return -1, nil
|
||||
}
|
||||
dims := make([]C.int64_t, ndims)
|
||||
for i, s := range s.dims {
|
||||
dims[i] = C.int64_t(s)
|
||||
}
|
||||
return ndims, dims
|
||||
}
|
||||
|
@ -58,3 +58,76 @@ func TestAddOperationFailure(t *testing.T) {
|
||||
_ = resize.Shape()
|
||||
t.Errorf("resize.Shape() should have paniced since the underlying Operation was not created")
|
||||
}
|
||||
|
||||
func TestShapeAttribute(t *testing.T) {
|
||||
s := NewScope()
|
||||
x := Placeholder(s.SubScope("x"), tf.Int32, PlaceholderShape(tf.MakeShape(1)))
|
||||
y := Placeholder(s.SubScope("y"), tf.Int32, PlaceholderShape(tf.Shape{}))
|
||||
z := Add(s, x, y)
|
||||
graph, err := s.Finalize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sess, err := tf.NewSession(graph, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
value, err := tf.NewTensor([]int32{7})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
feeds := map[tf.Output]*tf.Tensor{
|
||||
x: value,
|
||||
y: value,
|
||||
}
|
||||
fetched, err := sess.Run(feeds, []tf.Output{z}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := len(fetched), 1; got != want {
|
||||
t.Fatalf("Fetched %d tensors, expected %d", got, want)
|
||||
}
|
||||
if got, want := fetched[0].Value().([]int32), []int32{14}; len(got) != len(want) || len(got) != 1 || got[0] != want[0] {
|
||||
t.Fatalf("Got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataset(t *testing.T) {
|
||||
var (
|
||||
s = NewScope()
|
||||
|
||||
// The use of a non-scalar here is inspired by
|
||||
// https://github.com/tensorflow/tensorflow/issues/14891
|
||||
c = Const(s, []int32{21718, 31415})
|
||||
types = []tf.DataType{c.DataType()}
|
||||
shapes = []tf.Shape{c.Shape()}
|
||||
dataset = TensorDataset(s, []tf.Output{c}, shapes)
|
||||
|
||||
iterator = Iterator(s, "", "", types, shapes)
|
||||
next = IteratorGetNext(s, iterator, types, shapes)
|
||||
init = MakeIterator(s, dataset, iterator)
|
||||
)
|
||||
graph, err := s.Finalize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sess, err := tf.NewSession(graph, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := sess.Run(nil, nil, []*tf.Operation{init}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
results, err := sess.Run(nil, next, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := results[0].Value().([]int32)
|
||||
if len(got) != 2 || got[0] != 21718 || got[1] != 31415 {
|
||||
t.Errorf("Got %v, want {21718, 31415}", got)
|
||||
}
|
||||
if _, err := sess.Run(nil, next, nil); err == nil {
|
||||
t.Errorf("Expected sess.Run() to fail since the iterator should have reached the end of the dataset")
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user