Make tf.random_uniform([0], maxval=0, dtype=tf.int32) not crash

For integers, tf.random_uniform enforces a nonempty range with minval < maxval.
However, an empty range is fine if we're producing no output values, and
this degenerate case occurs naturally for some code patterns.

Thus, tf.random_uniform now allows empty ranges for integer random
numbers if the output shape is empty.
This commit is contained in:
Geoffrey Irving 2018-09-07 09:01:56 -07:00
parent 25c9913136
commit 97011c17de
2 changed files with 16 additions and 3 deletions
tensorflow
core/kernels
python/kernel_tests/random

View File

@ -231,7 +231,13 @@ class RandomUniformIntOp : public OpKernel {
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval.shape().DebugString()));
// Verify that minval < maxval
// Allocate output, and exit early if possible
Tensor* output;
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
if (output->NumElements() == 0) return;
// Verify that minval < maxval. This check intentionally happens after the
// early exit for empty output. Zero impossible things are fine.
IntType lo = minval.scalar<IntType>()();
IntType hi = maxval.scalar<IntType>()();
OP_REQUIRES(
@ -243,8 +249,6 @@ class RandomUniformIntOp : public OpKernel {
Distribution;
Distribution dist(lo, hi);
Tensor* output;
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
auto output_flat = output->flat<IntType>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(),

View File

@ -320,6 +320,15 @@ class RandomUniformTest(RandomOpTestCommon):
error = np.abs(counts - mean)
self.assertLess(error.max(), 5 * std)
# Check that minval = maxval is fine iff we're producing no numbers
def testUniformIntsDegenerate(self):
for dt in dtypes.int32, dtypes.int64:
def sample(n):
return self._Sampler(n, minv=0, maxv=0, dtype=dt, use_gpu=True)()
self.assertEqual(sample(0).shape, (10, 0))
with self.assertRaisesOpError('Need minval < maxval, got 0 >= 0'):
sample(1)
# Checks that the CPU and GPU implementation returns the same results,
# given the same random seed
def testCPUGPUMatch(self):