diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 6d884f32f83..df5b34cec89 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -215,25 +215,29 @@ func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) } func (t *Tensor) Shape() []int64 { return t.shape } // Reshape updates tensor's shape in place if this is possible or returns an error otherwise. -func (t *Tensor) Reshape(new_shape []int64) error { - old_shape_size := numElements(t.shape) - new_shape_size := numElements(new_shape) +func (t *Tensor) Reshape(newShape []int64) error { + oldShapeSize := numElements(t.shape) + newShapeSize := numElements(newShape) - if old_shape_size != new_shape_size { - return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, old_shape_size, new_shape, new_shape_size) + if oldShapeSize != newShapeSize { + return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, oldShapeSize, newShape, newShapeSize) } - if len(new_shape) == 0 { + if len(newShape) == 0 { return nil } var shapePtr *C.int64_t - shapePtr = (*C.int64_t)(unsafe.Pointer(&new_shape[0])) + shapePtr = (*C.int64_t)(unsafe.Pointer(&newShape[0])) status := newStatus() - C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(new_shape)), status.c) + C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(newShape)), status.c) - return status.Err() + if err := status.Err(); err != nil { + return err + } + t.shape = newShape + return nil } // Value converts the Tensor to a Go value. For now, not all Tensor types are diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go index 15b2ea55ad8..8aa710669a0 100644 --- a/tensorflow/go/tensor_test.go +++ b/tensorflow/go/tensor_test.go @@ -358,3 +358,31 @@ func BenchmarkTensor(b *testing.B) { }) } + +func TestReshape(t *testing.T) { + tensor, err := NewTensor([]int64{1, 2}) + if err != nil { + t.Fatalf("Unable to create new tensor: %v", err) + } + + if got, want := len(tensor.Shape()), 1; got != want { + t.Fatalf("len(tensor.Shape()): got %d, want %d", got, want) + } + if got, want := tensor.Shape()[0], int64(2); got != want { + t.Errorf("tensor.Shape()[0]: got %d, want %d", got, want) + } + + if err := tensor.Reshape([]int64{1, 2}); err != nil { + t.Fatalf("tensor.Reshape([1, 2]) failed: %v", err) + } + + if got, want := len(tensor.Shape()), 2; got != want { + t.Fatalf("After reshape, len(tensor.Shape()): got %d, want %d", got, want) + } + if got, want := tensor.Shape()[0], int64(1); got != want { + t.Errorf("After reshape, tensor.Shape()[0]: got %d, want %d", got, want) + } + if got, want := tensor.Shape()[1], int64(2); got != want { + t.Errorf("After reshape, tensor.Shape()[1]: got %d, want %d", got, want) + } +}