Fix bug in Tensor.Reshape.

The shape of a Tensor needs to be updated after a successful call to Reshape since the shape of the Tensor isn't fetched dynamically.

PiperOrigin-RevId: 346375411
Change-Id: Icf44a10bf72f175be5a16c7cebf6c47066427f56
This commit is contained in:
Phil Stahlfeld 2020-12-08 11:47:19 -08:00 committed by TensorFlower Gardener
parent 4e6fbd12cb
commit b6d5fa8ab6
2 changed files with 41 additions and 9 deletions

View File

@ -215,25 +215,29 @@ func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) }
func (t *Tensor) Shape() []int64 { return t.shape }
// Reshape updates tensor's shape in place if this is possible or returns an error otherwise.
func (t *Tensor) Reshape(new_shape []int64) error {
old_shape_size := numElements(t.shape)
new_shape_size := numElements(new_shape)
func (t *Tensor) Reshape(newShape []int64) error {
oldShapeSize := numElements(t.shape)
newShapeSize := numElements(newShape)
if old_shape_size != new_shape_size {
return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, old_shape_size, new_shape, new_shape_size)
if oldShapeSize != newShapeSize {
return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, oldShapeSize, newShape, newShapeSize)
}
if len(new_shape) == 0 {
if len(newShape) == 0 {
return nil
}
var shapePtr *C.int64_t
shapePtr = (*C.int64_t)(unsafe.Pointer(&new_shape[0]))
shapePtr = (*C.int64_t)(unsafe.Pointer(&newShape[0]))
status := newStatus()
C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(new_shape)), status.c)
C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(newShape)), status.c)
return status.Err()
if err := status.Err(); err != nil {
return err
}
t.shape = newShape
return nil
}
// Value converts the Tensor to a Go value. For now, not all Tensor types are

View File

@ -358,3 +358,31 @@ func BenchmarkTensor(b *testing.B) {
})
}
func TestReshape(t *testing.T) {
tensor, err := NewTensor([]int64{1, 2})
if err != nil {
t.Fatalf("Unable to create new tensor: %v", err)
}
if got, want := len(tensor.Shape()), 1; got != want {
t.Fatalf("len(tensor.Shape()): got %d, want %d", got, want)
}
if got, want := tensor.Shape()[0], int64(2); got != want {
t.Errorf("tensor.Shape()[0]: got %d, want %d", got, want)
}
if err := tensor.Reshape([]int64{1, 2}); err != nil {
t.Fatalf("tensor.Reshape([1, 2]) failed: %v", err)
}
if got, want := len(tensor.Shape()), 2; got != want {
t.Fatalf("After reshape, len(tensor.Shape()): got %d, want %d", got, want)
}
if got, want := tensor.Shape()[0], int64(1); got != want {
t.Errorf("After reshape, tensor.Shape()[0]: got %d, want %d", got, want)
}
if got, want := tensor.Shape()[1], int64(2); got != want {
t.Errorf("After reshape, tensor.Shape()[1]: got %d, want %d", got, want)
}
}