Update the boundary check in unravel_index to use std::all_of instead of eigen
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
b0375e267c
commit
4393031240
@ -55,21 +55,15 @@ class UnravelIndexOp : public OpKernel {
|
||||
auto dims = dims_tensor.vec<Tidx>();
|
||||
|
||||
// Chek to make sure indices is not out of boundary
|
||||
Eigen::Tensor<bool, 0, Eigen::RowMajor> check;
|
||||
if (TensorShapeUtils::IsScalar(indices_tensor.shape())) {
|
||||
auto indices = indices_tensor.scalar<Tidx>();
|
||||
auto dims_prod = dims.prod();
|
||||
check = (indices < dims_prod).all();
|
||||
} else {
|
||||
auto indices = indices_tensor.vec<Tidx>();
|
||||
auto dims_prod = dims.prod()
|
||||
.reshape(Eigen::array<Eigen::Index, 1>({1}))
|
||||
.broadcast(
|
||||
Eigen::array<Eigen::Index, 1>({indices_tensor.NumElements()}));
|
||||
check = (indices < dims_prod).all();
|
||||
}
|
||||
Eigen::Tensor<Tidx, 0, Eigen::RowMajor> dims_prod_eigen = dims.prod();
|
||||
Tidx dims_prod = dims_prod_eigen();
|
||||
const Tidx* indices = indices_tensor.flat<Tidx>().data();
|
||||
int64 size = indices_tensor.NumElements();
|
||||
bool check = std::all_of(indices, indices + size, [&](Tidx index) {
|
||||
return index < dims_prod;
|
||||
});
|
||||
OP_REQUIRES(
|
||||
ctx, check(),
|
||||
ctx, check,
|
||||
errors::InvalidArgument("index is out of bound as with dims"));
|
||||
|
||||
Eigen::array<bool, 1> reverse({true});
|
||||
|
Loading…
Reference in New Issue
Block a user