diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc index 8d839ba85a7..d41915e5c14 100644 --- a/tensorflow/core/kernels/unravel_index_op.cc +++ b/tensorflow/core/kernels/unravel_index_op.cc @@ -54,6 +54,24 @@ class UnravelIndexOp : public OpKernel { auto dims = dims_tensor.vec(); + // Chek to make sure indices is not out of boundary + Eigen::Tensor check; + if (TensorShapeUtils::IsScalar(indices_tensor.shape())) { + auto indices = indices_tensor.scalar(); + auto dims_prod = dims.prod(); + check = (indices < dims_prod).all(); + } else { + auto indices = indices_tensor.vec(); + auto dims_prod = dims.prod() + .reshape(Eigen::array({1})) + .broadcast( + Eigen::array({indices_tensor.NumElements()})); + check = (indices < dims_prod).all(); + } + OP_REQUIRES( + ctx, check(), + errors::InvalidArgument("index is out of bound as with dims")); + Eigen::array reverse({true}); Tensor strides_tensor; diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index ec3ed932996..caf05042557 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1384,6 +1384,16 @@ class UnravelIndexTest(test_util.TensorFlowTestCase): out_3 = array_ops.unravel_index(indices_3, dims_3) self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]]) + # Test case for GitHub issue 40204. + def testUnravelIndexZeroDim(self): + with self.cached_session(): + for dtype in [dtypes.int32, dtypes.int64]: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, "index is out of bound as with dims"): + indices = constant_op.constant([2, 5, 7], dtype=dtype) + dims = constant_op.constant([3, 0], dtype=dtype) + self.evaluate(array_ops.unravel_index(indices=indices, dims=dims)) + class GuaranteeConstOpTest(test_util.TensorFlowTestCase):