Fix bug in Tensor.Reshape.
The shape of a Tensor needs to be updated after a successful call to Reshape since the shape of the Tensor isn't fetched dynamically. PiperOrigin-RevId: 346375411 Change-Id: Icf44a10bf72f175be5a16c7cebf6c47066427f56
This commit is contained in:
parent
4e6fbd12cb
commit
b6d5fa8ab6
@ -215,25 +215,29 @@ func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) }
|
||||
func (t *Tensor) Shape() []int64 { return t.shape }
|
||||
|
||||
// Reshape updates tensor's shape in place if this is possible or returns an error otherwise.
|
||||
func (t *Tensor) Reshape(new_shape []int64) error {
|
||||
old_shape_size := numElements(t.shape)
|
||||
new_shape_size := numElements(new_shape)
|
||||
func (t *Tensor) Reshape(newShape []int64) error {
|
||||
oldShapeSize := numElements(t.shape)
|
||||
newShapeSize := numElements(newShape)
|
||||
|
||||
if old_shape_size != new_shape_size {
|
||||
return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, old_shape_size, new_shape, new_shape_size)
|
||||
if oldShapeSize != newShapeSize {
|
||||
return fmt.Errorf("unable to convert shape %v (num_elements: %d) into shape %v (num_elements: %d)", t.shape, oldShapeSize, newShape, newShapeSize)
|
||||
}
|
||||
|
||||
if len(new_shape) == 0 {
|
||||
if len(newShape) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var shapePtr *C.int64_t
|
||||
shapePtr = (*C.int64_t)(unsafe.Pointer(&new_shape[0]))
|
||||
shapePtr = (*C.int64_t)(unsafe.Pointer(&newShape[0]))
|
||||
|
||||
status := newStatus()
|
||||
C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(new_shape)), status.c)
|
||||
C.TF_TensorBitcastFrom(t.c, C.TF_TensorType(t.c), t.c, shapePtr, C.int(len(newShape)), status.c)
|
||||
|
||||
return status.Err()
|
||||
if err := status.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
t.shape = newShape
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value converts the Tensor to a Go value. For now, not all Tensor types are
|
||||
|
@ -358,3 +358,31 @@ func BenchmarkTensor(b *testing.B) {
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestReshape(t *testing.T) {
|
||||
tensor, err := NewTensor([]int64{1, 2})
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create new tensor: %v", err)
|
||||
}
|
||||
|
||||
if got, want := len(tensor.Shape()), 1; got != want {
|
||||
t.Fatalf("len(tensor.Shape()): got %d, want %d", got, want)
|
||||
}
|
||||
if got, want := tensor.Shape()[0], int64(2); got != want {
|
||||
t.Errorf("tensor.Shape()[0]: got %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if err := tensor.Reshape([]int64{1, 2}); err != nil {
|
||||
t.Fatalf("tensor.Reshape([1, 2]) failed: %v", err)
|
||||
}
|
||||
|
||||
if got, want := len(tensor.Shape()), 2; got != want {
|
||||
t.Fatalf("After reshape, len(tensor.Shape()): got %d, want %d", got, want)
|
||||
}
|
||||
if got, want := tensor.Shape()[0], int64(1); got != want {
|
||||
t.Errorf("After reshape, tensor.Shape()[0]: got %d, want %d", got, want)
|
||||
}
|
||||
if got, want := tensor.Shape()[1], int64(2); got != want {
|
||||
t.Errorf("After reshape, tensor.Shape()[1]: got %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user