Update the boundary check in unravel_index to use std::all_of instead of eigen

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-06-07 17:20:27 +00:00
parent b0375e267c
commit 4393031240

View File

@ -55,21 +55,15 @@ class UnravelIndexOp : public OpKernel {
auto dims = dims_tensor.vec<Tidx>();
// Chek to make sure indices is not out of boundary
Eigen::Tensor<bool, 0, Eigen::RowMajor> check;
if (TensorShapeUtils::IsScalar(indices_tensor.shape())) {
auto indices = indices_tensor.scalar<Tidx>();
auto dims_prod = dims.prod();
check = (indices < dims_prod).all();
} else {
auto indices = indices_tensor.vec<Tidx>();
auto dims_prod = dims.prod()
.reshape(Eigen::array<Eigen::Index, 1>({1}))
.broadcast(
Eigen::array<Eigen::Index, 1>({indices_tensor.NumElements()}));
check = (indices < dims_prod).all();
}
Eigen::Tensor<Tidx, 0, Eigen::RowMajor> dims_prod_eigen = dims.prod();
Tidx dims_prod = dims_prod_eigen();
const Tidx* indices = indices_tensor.flat<Tidx>().data();
int64 size = indices_tensor.NumElements();
bool check = std::all_of(indices, indices + size, [&](Tidx index) {
return index < dims_prod;
});
OP_REQUIRES(
ctx, check(),
ctx, check,
errors::InvalidArgument("index is out of bound as with dims"));
Eigen::array<bool, 1> reverse({true});