Call uniform_full_int in tf.random.Generator.uniform when dtype is integer and both minval and maxval are None, to align with stateless_random_uniform.

PiperOrigin-RevId: 299455876
Change-Id: I7274a15bef63ea349631e2f3053c71af387c2ca1
This commit is contained in:
Peng Wang 2020-03-06 16:01:57 -08:00 committed by TensorFlower Gardener
parent 1e2122cf33
commit 3e95e9f03a
3 changed files with 38 additions and 13 deletions

View File

@ -666,6 +666,11 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
return gen_stateful_random_ops.stateful_uniform(
self.state.handle, self.algorithm, shape=shape, dtype=dtype)
def _uniform_full_int(self, shape, dtype, name=None):
return gen_stateful_random_ops.stateful_uniform_full_int(
self.state.handle, self.algorithm, shape=shape,
dtype=dtype, name=name)
def uniform(self, shape, minval=0, maxval=None,
dtype=dtypes.float32, name=None):
"""Outputs random values from a uniform distribution.
@ -684,17 +689,22 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
`maxval - minval` significantly smaller than the range of the output (either
`2**32` or `2**64`).
For full-range random integers, pass `minval=None` and `maxval=None` with an
integer `dtype` (for integer dtypes, `minval` and `maxval` must be both
`None` or both not `None`).
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output
tensor.
minval: A Tensor or Python value of type `dtype`, broadcastable with
`shape` (for integer types, broadcasting is not supported, so it needs
to be a scalar). The lower bound on the range of random values to
generate. Defaults to 0.
to be a scalar). The lower bound (included) on the range of random
values to generate. Pass `None` for full-range integers. Defaults to 0.
maxval: A Tensor or Python value of type `dtype`, broadcastable with
`shape` (for integer types, broadcasting is not supported, so it needs
to be a scalar). The upper bound on the range of random values to
generate. Defaults to 1 if `dtype` is floating point.
to be a scalar). The upper bound (excluded) on the range of random
values to generate. Pass `None` for full-range integers. Defaults to 1
if `dtype` is floating point.
dtype: The type of the output.
name: A name for the operation (optional).
@ -705,13 +715,18 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
ValueError: If `dtype` is integral and `maxval` is not specified.
"""
dtype = dtypes.as_dtype(dtype)
if maxval is None:
if dtype.is_integer:
raise ValueError("Must specify maxval for integer dtype %r" % dtype)
if dtype.is_integer:
if (minval is None) != (maxval is None):
raise ValueError("For integer dtype {}, minval and maxval must be both "
"`None` or both non-`None`; got minval={} and "
"maxval={}".format(dtype, minval, maxval))
elif maxval is None:
maxval = 1
with ops.name_scope(name, "stateful_uniform",
[shape, minval, maxval]) as name:
shape = _shape_tensor(shape)
if dtype.is_integer and minval is None:
return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
if dtype.is_integer:
@ -725,8 +740,8 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
"""Uniform distribution on an integer type's entire range.
The other method `uniform` only covers the range [minval, maxval), which
cannot be `dtype`'s full range because `maxval` is of type `dtype`.
This method is the same as setting `minval` and `maxval` to `None` in the
`uniform` method.
Args:
shape: the shape of the output.
@ -740,9 +755,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
with ops.name_scope(name, "stateful_uniform_full_int",
[shape]) as name:
shape = _shape_tensor(shape)
return gen_stateful_random_ops.stateful_uniform_full_int(
self.state.handle, self.algorithm, shape=shape,
dtype=dtype, name=name)
return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
"""Outputs random values from a binomial distribution.

View File

@ -748,6 +748,18 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(2, len(local_results))
self.assertAllDifferent(local_results)
@test_util.run_v2_only
def testUniformFullInt(self):
"""Tests full-range int uniform.
"""
shape = [3, 4]
dtype = dtypes.int32
g = random.Generator.from_seed(1)
r1 = g.uniform(shape=shape, dtype=dtype, minval=None)
g = random.Generator.from_seed(1)
r2 = g.uniform_full_int(shape=shape, dtype=dtype)
self.assertAllEqual(r1, r2)
if __name__ == "__main__":
test.main()

View File

@ -66,7 +66,7 @@ def stateless_random_uniform(shape,
`maxval - minval` significantly smaller than the range of the output (either
`2**32` or `2**64`).
For full full-range (i.e. inclusive of both max and min) random integers, pass
For full-range (i.e. inclusive of both max and min) random integers, pass
`minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
either both `minval` and `maxval` must be `None` or neither may be `None`. For
example: