Merge pull request from yongtang:40204-tf.unravel_index

PiperOrigin-RevId: 315547243
Change-Id: I7d6a2f54575dc5923a9ee66d013e907acdb308e6
This commit is contained in:
TensorFlower Gardener 2020-06-09 13:32:49 -07:00
commit 59eeafcc29
2 changed files with 20 additions and 0 deletions
tensorflow
core/kernels
python/kernel_tests

View File

@ -54,6 +54,16 @@ class UnravelIndexOp : public OpKernel {
auto dims = dims_tensor.vec<Tidx>();
// Chek to make sure indices is not out of boundary
Eigen::Tensor<Tidx, 0, Eigen::RowMajor> dims_prod_eigen = dims.prod();
Tidx dims_prod = dims_prod_eigen();
const Tidx* indices = indices_tensor.flat<Tidx>().data();
int64 size = indices_tensor.NumElements();
bool check = std::all_of(indices, indices + size,
[&](Tidx index) { return index < dims_prod; });
OP_REQUIRES(ctx, check,
errors::InvalidArgument("index is out of bound as with dims"));
Eigen::array<bool, 1> reverse({true});
Tensor strides_tensor;

View File

@ -1436,6 +1436,16 @@ class UnravelIndexTest(test_util.TensorFlowTestCase):
out_3 = array_ops.unravel_index(indices_3, dims_3)
self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]])
# Test case for GitHub issue 40204.
def testUnravelIndexZeroDim(self):
with self.cached_session():
for dtype in [dtypes.int32, dtypes.int64]:
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"index is out of bound as with dims"):
indices = constant_op.constant([2, 5, 7], dtype=dtype)
dims = constant_op.constant([3, 0], dtype=dtype)
self.evaluate(array_ops.unravel_index(indices=indices, dims=dims))
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):