Make tf.random_uniform([0], maxval=0, dtype=tf.int32) not crash
For integers, tf.random_uniform enforces a nonempty range with minval < maxval. However, an empty range is fine if we're producing no output values, and this degenerate case occurs naturally for some code patterns. Thus, tf.random_uniform now allows empty ranges for integer random numbers if the output shape is empty.
This commit is contained in:
parent
25c9913136
commit
97011c17de
tensorflow
@ -231,7 +231,13 @@ class RandomUniformIntOp : public OpKernel {
|
||||
errors::InvalidArgument("maxval must be 0-D, got shape ",
|
||||
maxval.shape().DebugString()));
|
||||
|
||||
// Verify that minval < maxval
|
||||
// Allocate output, and exit early if possible
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
|
||||
if (output->NumElements() == 0) return;
|
||||
|
||||
// Verify that minval < maxval. This check intentionally happens after the
|
||||
// early exit for empty output. Zero impossible things are fine.
|
||||
IntType lo = minval.scalar<IntType>()();
|
||||
IntType hi = maxval.scalar<IntType>()();
|
||||
OP_REQUIRES(
|
||||
@ -243,8 +249,6 @@ class RandomUniformIntOp : public OpKernel {
|
||||
Distribution;
|
||||
Distribution dist(lo, hi);
|
||||
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
|
||||
auto output_flat = output->flat<IntType>();
|
||||
functor::FillPhiloxRandom<Device, Distribution>()(
|
||||
ctx, ctx->eigen_device<Device>(),
|
||||
|
@ -320,6 +320,15 @@ class RandomUniformTest(RandomOpTestCommon):
|
||||
error = np.abs(counts - mean)
|
||||
self.assertLess(error.max(), 5 * std)
|
||||
|
||||
# Check that minval = maxval is fine iff we're producing no numbers
|
||||
def testUniformIntsDegenerate(self):
|
||||
for dt in dtypes.int32, dtypes.int64:
|
||||
def sample(n):
|
||||
return self._Sampler(n, minv=0, maxv=0, dtype=dt, use_gpu=True)()
|
||||
self.assertEqual(sample(0).shape, (10, 0))
|
||||
with self.assertRaisesOpError('Need minval < maxval, got 0 >= 0'):
|
||||
sample(1)
|
||||
|
||||
# Checks that the CPU and GPU implementation returns the same results,
|
||||
# given the same random seed
|
||||
def testCPUGPUMatch(self):
|
||||
|
Loading…
Reference in New Issue
Block a user