Fix floating point exception with tf.unravel_index
This PR tries to address the issue raised in 40204 where `tf.unravel_index` caused floating point exception when any one element is 0 in `dims`. The issue is that `indices` in `tf.unravel_index` should make sure it is not out of boundary compared to `dims`. This PR fixes the issue by adding a check before hand, though Eigen. This PR fixes 40204. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
ae76544efc
commit
b0375e267c
@ -54,6 +54,24 @@ class UnravelIndexOp : public OpKernel {
|
|||||||
|
|
||||||
auto dims = dims_tensor.vec<Tidx>();
|
auto dims = dims_tensor.vec<Tidx>();
|
||||||
|
|
||||||
|
// Chek to make sure indices is not out of boundary
|
||||||
|
Eigen::Tensor<bool, 0, Eigen::RowMajor> check;
|
||||||
|
if (TensorShapeUtils::IsScalar(indices_tensor.shape())) {
|
||||||
|
auto indices = indices_tensor.scalar<Tidx>();
|
||||||
|
auto dims_prod = dims.prod();
|
||||||
|
check = (indices < dims_prod).all();
|
||||||
|
} else {
|
||||||
|
auto indices = indices_tensor.vec<Tidx>();
|
||||||
|
auto dims_prod = dims.prod()
|
||||||
|
.reshape(Eigen::array<Eigen::Index, 1>({1}))
|
||||||
|
.broadcast(
|
||||||
|
Eigen::array<Eigen::Index, 1>({indices_tensor.NumElements()}));
|
||||||
|
check = (indices < dims_prod).all();
|
||||||
|
}
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, check(),
|
||||||
|
errors::InvalidArgument("index is out of bound as with dims"));
|
||||||
|
|
||||||
Eigen::array<bool, 1> reverse({true});
|
Eigen::array<bool, 1> reverse({true});
|
||||||
|
|
||||||
Tensor strides_tensor;
|
Tensor strides_tensor;
|
||||||
|
@ -1384,6 +1384,16 @@ class UnravelIndexTest(test_util.TensorFlowTestCase):
|
|||||||
out_3 = array_ops.unravel_index(indices_3, dims_3)
|
out_3 = array_ops.unravel_index(indices_3, dims_3)
|
||||||
self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]])
|
self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]])
|
||||||
|
|
||||||
|
# Test case for GitHub issue 40204.
|
||||||
|
def testUnravelIndexZeroDim(self):
|
||||||
|
with self.cached_session():
|
||||||
|
for dtype in [dtypes.int32, dtypes.int64]:
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
errors.InvalidArgumentError, "index is out of bound as with dims"):
|
||||||
|
indices = constant_op.constant([2, 5, 7], dtype=dtype)
|
||||||
|
dims = constant_op.constant([3, 0], dtype=dtype)
|
||||||
|
self.evaluate(array_ops.unravel_index(indices=indices, dims=dims))
|
||||||
|
|
||||||
|
|
||||||
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
|
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user