Make tf.random_uniform([0], maxval=0, dtype=tf.int32) not crash
For integers, tf.random_uniform enforces a nonempty range with minval < maxval. However, an empty range is fine if we're producing no output values, and this degenerate case occurs naturally for some code patterns. Thus, tf.random_uniform now allows empty ranges for integer random numbers if the output shape is empty.
This commit is contained in:
parent
25c9913136
commit
97011c17de
@ -231,7 +231,13 @@ class RandomUniformIntOp : public OpKernel {
|
|||||||
errors::InvalidArgument("maxval must be 0-D, got shape ",
|
errors::InvalidArgument("maxval must be 0-D, got shape ",
|
||||||
maxval.shape().DebugString()));
|
maxval.shape().DebugString()));
|
||||||
|
|
||||||
// Verify that minval < maxval
|
// Allocate output, and exit early if possible
|
||||||
|
Tensor* output;
|
||||||
|
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
|
||||||
|
if (output->NumElements() == 0) return;
|
||||||
|
|
||||||
|
// Verify that minval < maxval. This check intentionally happens after the
|
||||||
|
// early exit for empty output. Zero impossible things are fine.
|
||||||
IntType lo = minval.scalar<IntType>()();
|
IntType lo = minval.scalar<IntType>()();
|
||||||
IntType hi = maxval.scalar<IntType>()();
|
IntType hi = maxval.scalar<IntType>()();
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
@ -243,8 +249,6 @@ class RandomUniformIntOp : public OpKernel {
|
|||||||
Distribution;
|
Distribution;
|
||||||
Distribution dist(lo, hi);
|
Distribution dist(lo, hi);
|
||||||
|
|
||||||
Tensor* output;
|
|
||||||
OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
|
|
||||||
auto output_flat = output->flat<IntType>();
|
auto output_flat = output->flat<IntType>();
|
||||||
functor::FillPhiloxRandom<Device, Distribution>()(
|
functor::FillPhiloxRandom<Device, Distribution>()(
|
||||||
ctx, ctx->eigen_device<Device>(),
|
ctx, ctx->eigen_device<Device>(),
|
||||||
|
@ -320,6 +320,15 @@ class RandomUniformTest(RandomOpTestCommon):
|
|||||||
error = np.abs(counts - mean)
|
error = np.abs(counts - mean)
|
||||||
self.assertLess(error.max(), 5 * std)
|
self.assertLess(error.max(), 5 * std)
|
||||||
|
|
||||||
|
# Check that minval = maxval is fine iff we're producing no numbers
|
||||||
|
def testUniformIntsDegenerate(self):
|
||||||
|
for dt in dtypes.int32, dtypes.int64:
|
||||||
|
def sample(n):
|
||||||
|
return self._Sampler(n, minv=0, maxv=0, dtype=dt, use_gpu=True)()
|
||||||
|
self.assertEqual(sample(0).shape, (10, 0))
|
||||||
|
with self.assertRaisesOpError('Need minval < maxval, got 0 >= 0'):
|
||||||
|
sample(1)
|
||||||
|
|
||||||
# Checks that the CPU and GPU implementation returns the same results,
|
# Checks that the CPU and GPU implementation returns the same results,
|
||||||
# given the same random seed
|
# given the same random seed
|
||||||
def testCPUGPUMatch(self):
|
def testCPUGPUMatch(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user