Make tf.reshape(tensor, shape) work when `tensor` has a zero-sized dimension and `shape` contains an unknown (-1) dimension.

PiperOrigin-RevId: 243094911
This commit is contained in:
Ruoxin Sang 2019-04-11 10:27:40 -07:00 committed by TensorFlower Gardener
parent 86a6d44352
commit 36e10587b5
3 changed files with 97 additions and 29 deletions

View File

@ -51,6 +51,7 @@ class ReshapeOp : public XlaOpKernel {
TensorShape shape; TensorShape shape;
int64 product = 1; int64 product = 1;
int unknown_index = -1; int unknown_index = -1;
bool shape_has_zero_dim = false;
for (int d = 0; d < num_dims; ++d) { for (int d = 0; d < num_dims; ++d) {
const int32 size = shape_input[d]; const int32 size = shape_input[d];
if (size == -1) { if (size == -1) {
@ -60,6 +61,12 @@ class ReshapeOp : public XlaOpKernel {
unknown_index, " and ", d)); unknown_index, " and ", d));
unknown_index = d; unknown_index = d;
shape.AddDim(1); shape.AddDim(1);
} else if (size == 0) {
// We don't include zero-sized dimension in product, so that we can
// still calculate number of elements for non-zero-sized dimensions and
// therefore infer their shapes.
shape.AddDim(size);
shape_has_zero_dim = true;
} else { } else {
OP_REQUIRES(ctx, size >= 0, OP_REQUIRES(ctx, size >= 0,
errors::InvalidArgument( errors::InvalidArgument(
@ -69,18 +76,28 @@ class ReshapeOp : public XlaOpKernel {
} }
} }
if (unknown_index != -1) { if (unknown_index != -1) {
int64 input_num_elements = 1;
bool input_has_zero_dim = false;
for (int dim = 0; dim < input_shape.dims(); dim++) {
// A zero-sized input dimension is left out of `input_num_elements`
// when the requested shape also has a zero-sized dimension
// (`shape_has_zero_dim`), so that the unknown dimension can still be
// inferred from the remaining non-zero sizes.
if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) {
input_num_elements *= input_shape.dim_size(dim);
} else {
input_has_zero_dim = true;
}
}
const int64 missing = input_num_elements / product;
if (!input_has_zero_dim) {
OP_REQUIRES( OP_REQUIRES(
ctx, product > 0, ctx, product * missing == input_num_elements,
errors::InvalidArgument("Reshape cannot infer the missing input size "
"for an empty tensor unless all specified "
"input sizes are non-zero"));
const int64 missing = input_shape.num_elements() / product;
OP_REQUIRES(
ctx, product * missing == input_shape.num_elements(),
errors::InvalidArgument( errors::InvalidArgument(
"Input to reshape is a tensor with ", input_shape.num_elements(), "Input to reshape is a tensor with ", input_num_elements,
" values, but the requested shape requires a multiple of ", " values, but the requested shape requires a multiple of ",
product)); product));
}
shape.set_dim(unknown_index, missing); shape.set_dim(unknown_index, missing);
} }
OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(), OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),

View File

@ -45,14 +45,17 @@ class ReshapeOp : public OpKernel {
TensorShape shape; TensorShape shape;
int64 product = 1; int64 product = 1;
int unknown_index = -1; int unknown_index = -1;
bool sizes_has_zero_dim;
switch (sizes.dtype()) { switch (sizes.dtype()) {
case DT_INT32: case DT_INT32:
OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product, OP_REQUIRES_OK(context,
&unknown_index, &shape)); ValidateSizes<int32>(sizes, &product, &unknown_index,
&shape, &sizes_has_zero_dim));
break; break;
case DT_INT64: case DT_INT64:
OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product, OP_REQUIRES_OK(context,
&unknown_index, &shape)); ValidateSizes<int64>(sizes, &product, &unknown_index,
&shape, &sizes_has_zero_dim));
break; break;
default: default:
context->CtxFailure(errors::InvalidArgument( context->CtxFailure(errors::InvalidArgument(
@ -61,18 +64,28 @@ class ReshapeOp : public OpKernel {
return; return;
} }
if (unknown_index != -1) { if (unknown_index != -1) {
int64 input_num_elements = 1;
bool input_has_zero_dim = false;
for (int dim = 0; dim < input.dims(); dim++) {
// For zero dimension, we don't count it into `input_num_elements`
// unless `sizes` has no zero dimension, so we are still able to
// infer shapes for other dimensions.
if (input.dim_size(dim) > 0 || !sizes_has_zero_dim) {
input_num_elements *= input.dim_size(dim);
} else {
input_has_zero_dim = true;
}
}
const int64 missing = input_num_elements / product;
if (!input_has_zero_dim) {
OP_REQUIRES( OP_REQUIRES(
context, product > 0, context, product * missing == input_num_elements,
errors::InvalidArgument("Reshape cannot infer the missing input size "
"for an empty tensor unless all specified "
"input sizes are non-zero"));
const int64 missing = input.NumElements() / product;
OP_REQUIRES(
context, product * missing == input.NumElements(),
errors::InvalidArgument( errors::InvalidArgument(
"Input to reshape is a tensor with ", input.NumElements(), "Input to reshape is a tensor with ", input_num_elements,
" values, but the requested shape requires a multiple of ", " values, but the requested shape requires a multiple of ",
product)); product));
}
shape.set_dim(unknown_index, missing); shape.set_dim(unknown_index, missing);
} }
OP_REQUIRES(context, shape.num_elements() == input.NumElements(), OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
@ -92,9 +105,10 @@ class ReshapeOp : public OpKernel {
private: private:
template <typename Tshape> template <typename Tshape>
Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index, Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index,
TensorShape* shape) { TensorShape* shape, bool* has_zero_dim) {
*product = 1; *product = 1;
*unknown_index = -1; *unknown_index = -1;
*has_zero_dim = false;
const int64 num_dims = sizes.NumElements(); const int64 num_dims = sizes.NumElements();
auto Svec = sizes.flat<Tshape>(); auto Svec = sizes.flat<Tshape>();
for (int d = 0; d < num_dims; ++d) { for (int d = 0; d < num_dims; ++d) {
@ -110,6 +124,12 @@ class ReshapeOp : public OpKernel {
} else if (size < 0) { } else if (size < 0) {
return errors::InvalidArgument("Size ", d, return errors::InvalidArgument("Size ", d,
" must be non-negative, not ", size); " must be non-negative, not ", size);
} else if (size == 0) {
// We don't include zero-sized dimension in product, so that we can
// still calculate number of elements for non-zero-sized dimensions and
// therefore infer their shapes.
shape->AddDim(size);
*has_zero_dim = true;
} else { } else {
shape->AddDim(size); shape->AddDim(size);
(*product) *= size; (*product) *= size;

View File

@ -45,6 +45,18 @@ class ReshapeTest(test.TestCase):
self.assertEqual(tf_ans.get_shape(), out.shape) self.assertEqual(tf_ans.get_shape(), out.shape)
self.assertShapeEqual(np_ans, tf_ans) self.assertShapeEqual(np_ans, tf_ans)
def _testZeroDimReshape(self, x, shape, expected, use_gpu=False):
  """Asserts that reshaping `x` to `shape` evaluates to shape `expected`.

  The reshape is evaluated twice: once with `shape` passed through as given,
  and once with `shape` converted to an int64 constant tensor.

  Args:
    x: Input value handed to `array_ops.reshape`.
    shape: Requested shape; also used to build the int64 shape constant.
    expected: Shape that the evaluated result must have in both passes.
    use_gpu: Forwarded to `self.cached_session`.
  """
  # NOTE(review): indentation was not visible in this view; both passes are
  # assumed to run inside the session block -- confirm against the file.
  with self.cached_session(use_gpu=use_gpu):
    # First pass: `shape` exactly as the caller supplied it.
    result = self.evaluate(array_ops.reshape(x, shape))
    self.assertEqual(expected, result.shape)

    # Second pass: the same shape expressed as an int64 tensor.
    shape_as_int64 = constant_op.constant(shape, dtype=dtypes.int64)
    result = self.evaluate(array_ops.reshape(x, shape_as_int64))
    self.assertEqual(expected, result.shape)
def _testBothReshape(self, x, y): def _testBothReshape(self, x, y):
self._testReshape(x, y, False) self._testReshape(x, y, False)
self._testReshape(x, y, True) self._testReshape(x, y, True)
@ -89,6 +101,18 @@ class ReshapeTest(test.TestCase):
x = np.arange(1., 7.).reshape([6]).astype(np.float32) x = np.arange(1., 7.).reshape([6]).astype(np.float32)
self._testBothReshape(x, [3, -1]) self._testBothReshape(x, [3, -1])
def testZeroDimBasic(self):
x = np.zeros([0, 6]).astype(np.float32)
self._testBothReshape(x, [0, 2, 3])
def testZeroDimReshapeR1(self):
x = np.zeros([0, 6]).astype(np.float32)
self._testBothReshape(x, [-1])
def testZeroDimReshapeR3(self):
x = np.zeros([0, 6]).astype(np.float32)
self._testBothReshape(x, [-1, 2, 3])
# TODO(vrv): Add tests for failure conditions once python test_util # TODO(vrv): Add tests for failure conditions once python test_util
# reports errors. # reports errors.
@ -113,6 +137,13 @@ class ReshapeTest(test.TestCase):
self._testBothReshape(x, [0, 0, 0]) self._testBothReshape(x, [0, 0, 0])
self._testBothReshape(x, [1, -1, 5]) self._testBothReshape(x, [1, -1, 5])
def testZeroDimWithUnspecifiedDim(self):
for use_gpu in (True, False):
self._testZeroDimReshape(x=np.zeros([0, 6]).astype(np.float32),
shape=[0, -1, 3],
expected=(0, 2, 3),
use_gpu=use_gpu)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testErrors(self): def testErrors(self):
y = constant_op.constant(0.0, shape=[23, 29, 31]) y = constant_op.constant(0.0, shape=[23, 29, 31])