diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index fa1b6b91710..6cf4c01f1b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -51,6 +51,7 @@ class ReshapeOp : public XlaOpKernel { TensorShape shape; int64 product = 1; int unknown_index = -1; + bool shape_has_zero_dim = false; for (int d = 0; d < num_dims; ++d) { const int32 size = shape_input[d]; if (size == -1) { @@ -60,6 +61,12 @@ class ReshapeOp : public XlaOpKernel { unknown_index, " and ", d)); unknown_index = d; shape.AddDim(1); + } else if (size == 0) { + // We don't include zero-sized dimension in product, so that we can + // still calculate number of elements for non-zero-sized dimensions and + // therefore infer their shapes. + shape.AddDim(size); + shape_has_zero_dim = true; } else { OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument( @@ -69,18 +76,28 @@ class ReshapeOp : public XlaOpKernel { } } if (unknown_index != -1) { - OP_REQUIRES( - ctx, product > 0, - errors::InvalidArgument("Reshape cannot infer the missing input size " - "for an empty tensor unless all specified " - "input sizes are non-zero")); - const int64 missing = input_shape.num_elements() / product; - OP_REQUIRES( - ctx, product * missing == input_shape.num_elements(), - errors::InvalidArgument( - "Input to reshape is a tensor with ", input_shape.num_elements(), - " values, but the requested shape requires a multiple of ", - product)); + int64 input_num_elements = 1; + bool input_has_zero_dim = false; + for (int dim = 0; dim < input_shape.dims(); dim++) { + // For zero dimension, we don't count it into `input_num_elements` + // unless `sizes` has no zero dimension, so we are still able to + // infer shapes for other dimensions. + if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) { + input_num_elements *= input_shape.dim_size(dim); + } else { + input_has_zero_dim = true; + } + } + + const int64 missing = input_num_elements / product; + if (!input_has_zero_dim) { + OP_REQUIRES( + ctx, product * missing == input_num_elements, + errors::InvalidArgument( + "Input to reshape is a tensor with ", input_num_elements, + " values, but the requested shape requires a multiple of ", + product)); + } shape.set_dim(unknown_index, missing); } OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(), diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h index 7458ac75ca0..47cd219d8cf 100644 --- a/tensorflow/core/kernels/reshape_op.h +++ b/tensorflow/core/kernels/reshape_op.h @@ -45,14 +45,17 @@ class ReshapeOp : public OpKernel { TensorShape shape; int64 product = 1; int unknown_index = -1; + bool sizes_has_zero_dim; switch (sizes.dtype()) { case DT_INT32: - OP_REQUIRES_OK(context, ValidateSizes(sizes, &product, - &unknown_index, &shape)); + OP_REQUIRES_OK(context, + ValidateSizes(sizes, &product, &unknown_index, + &shape, &sizes_has_zero_dim)); break; case DT_INT64: - OP_REQUIRES_OK(context, ValidateSizes(sizes, &product, - &unknown_index, &shape)); + OP_REQUIRES_OK(context, + ValidateSizes(sizes, &product, &unknown_index, + &shape, &sizes_has_zero_dim)); break; default: context->CtxFailure(errors::InvalidArgument( @@ -61,18 +64,28 @@ class ReshapeOp : public OpKernel { return; } if (unknown_index != -1) { - OP_REQUIRES( - context, product > 0, - errors::InvalidArgument("Reshape cannot infer the missing input size " - "for an empty tensor unless all specified " - "input sizes are non-zero")); - const int64 missing = input.NumElements() / product; - OP_REQUIRES( - context, product * missing == input.NumElements(), - errors::InvalidArgument( - "Input to reshape is a tensor with ", input.NumElements(), - " values, but the requested shape requires a multiple of ", - product)); + int64 input_num_elements = 1; + bool input_has_zero_dim = false; + for (int dim = 0; dim < input.dims(); dim++) { + // For zero dimension, we don't count it into `input_num_elements` + // unless `sizes` has no zero dimension, so we are still able to + // infer shapes for other dimensions. + if (input.dim_size(dim) > 0 || !sizes_has_zero_dim) { + input_num_elements *= input.dim_size(dim); + } else { + input_has_zero_dim = true; + } + } + + const int64 missing = input_num_elements / product; + if (!input_has_zero_dim) { + OP_REQUIRES( + context, product * missing == input_num_elements, + errors::InvalidArgument( + "Input to reshape is a tensor with ", input_num_elements, + " values, but the requested shape requires a multiple of ", + product)); + } shape.set_dim(unknown_index, missing); } OP_REQUIRES(context, shape.num_elements() == input.NumElements(), @@ -92,9 +105,10 @@ class ReshapeOp : public OpKernel { private: template Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index, - TensorShape* shape) { + TensorShape* shape, bool* has_zero_dim) { *product = 1; *unknown_index = -1; + *has_zero_dim = false; const int64 num_dims = sizes.NumElements(); auto Svec = sizes.flat(); for (int d = 0; d < num_dims; ++d) { @@ -110,6 +124,12 @@ class ReshapeOp : public OpKernel { } else if (size < 0) { return errors::InvalidArgument("Size ", d, " must be non-negative, not ", size); + } else if (size == 0) { + // We don't include zero-sized dimension in product, so that we can + // still calculate number of elements for non-zero-sized dimensions and + // therefore infer their shapes. + shape->AddDim(size); + *has_zero_dim = true; } else { shape->AddDim(size); (*product) *= size; diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py index db3e88a104f..264838d7ae5 100644 --- a/tensorflow/python/kernel_tests/reshape_op_test.py +++ b/tensorflow/python/kernel_tests/reshape_op_test.py @@ -45,6 +45,18 @@ class ReshapeTest(test.TestCase): self.assertEqual(tf_ans.get_shape(), out.shape) self.assertShapeEqual(np_ans, tf_ans) + def _testZeroDimReshape(self, x, shape, expected, use_gpu=False): + with self.cached_session(use_gpu=use_gpu): + y = array_ops.reshape(x, shape) + out = self.evaluate(y) + self.assertEqual(expected, out.shape) + + # Repeat with an int64 shape tensor. + shape64 = constant_op.constant(shape, dtype=dtypes.int64) + y = array_ops.reshape(x, shape64) + out = self.evaluate(y) + self.assertEqual(expected, out.shape) + def _testBothReshape(self, x, y): self._testReshape(x, y, False) self._testReshape(x, y, True) @@ -89,6 +101,18 @@ class ReshapeTest(test.TestCase): x = np.arange(1., 7.).reshape([6]).astype(np.float32) self._testBothReshape(x, [3, -1]) + def testZeroDimBasic(self): + x = np.zeros([0, 6]).astype(np.float32) + self._testBothReshape(x, [0, 2, 3]) + + def testZeroDimReshapeR1(self): + x = np.zeros([0, 6]).astype(np.float32) + self._testBothReshape(x, [-1]) + + def testZeroDimReshapeR3(self): + x = np.zeros([0, 6]).astype(np.float32) + self._testBothReshape(x, [-1, 2, 3]) + # TODO(vrv): Add tests for failure conditions once python test_util # reports errors. @@ -113,6 +137,13 @@ class ReshapeTest(test.TestCase): self._testBothReshape(x, [0, 0, 0]) self._testBothReshape(x, [1, -1, 5]) + def testZeroDimWithUnspecifiedDim(self): + for use_gpu in (True, False): + self._testZeroDimReshape(x=np.zeros([0, 6]).astype(np.float32), + shape=[0, -1, 3], + expected=(0, 2, 3), + use_gpu=use_gpu) + @test_util.run_deprecated_v1 def testErrors(self): y = constant_op.constant(0.0, shape=[23, 29, 31])