diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc index 8d839ba85a7..b45ff5e5b85 100644 --- a/tensorflow/core/kernels/unravel_index_op.cc +++ b/tensorflow/core/kernels/unravel_index_op.cc @@ -54,6 +54,16 @@ class UnravelIndexOp : public OpKernel { auto dims = dims_tensor.vec(); + // Chek to make sure indices is not out of boundary + Eigen::Tensor dims_prod_eigen = dims.prod(); + Tidx dims_prod = dims_prod_eigen(); + const Tidx* indices = indices_tensor.flat().data(); + int64 size = indices_tensor.NumElements(); + bool check = std::all_of(indices, indices + size, + [&](Tidx index) { return index < dims_prod; }); + OP_REQUIRES(ctx, check, + errors::InvalidArgument("index is out of bound as with dims")); + Eigen::array reverse({true}); Tensor strides_tensor; diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 9eb8bfcef41..dbff3a1b2f7 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1436,6 +1436,16 @@ class UnravelIndexTest(test_util.TensorFlowTestCase): out_3 = array_ops.unravel_index(indices_3, dims_3) self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]]) + # Test case for GitHub issue 40204. + def testUnravelIndexZeroDim(self): + with self.cached_session(): + for dtype in [dtypes.int32, dtypes.int64]: + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "index is out of bound as with dims"): + indices = constant_op.constant([2, 5, 7], dtype=dtype) + dims = constant_op.constant([3, 0], dtype=dtype) + self.evaluate(array_ops.unravel_index(indices=indices, dims=dims)) + class GuaranteeConstOpTest(test_util.TensorFlowTestCase):