Call uniform_full_int in tf.random.Generator.uniform when dtype is integer and both minval and maxval are None, to align with stateless_random_uniform.
PiperOrigin-RevId: 299455876 Change-Id: I7274a15bef63ea349631e2f3053c71af387c2ca1
This commit is contained in:
parent
1e2122cf33
commit
3e95e9f03a
@ -666,6 +666,11 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
return gen_stateful_random_ops.stateful_uniform(
|
||||
self.state.handle, self.algorithm, shape=shape, dtype=dtype)
|
||||
|
||||
def _uniform_full_int(self, shape, dtype, name=None):
|
||||
return gen_stateful_random_ops.stateful_uniform_full_int(
|
||||
self.state.handle, self.algorithm, shape=shape,
|
||||
dtype=dtype, name=name)
|
||||
|
||||
def uniform(self, shape, minval=0, maxval=None,
|
||||
dtype=dtypes.float32, name=None):
|
||||
"""Outputs random values from a uniform distribution.
|
||||
@ -684,17 +689,22 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
`maxval - minval` significantly smaller than the range of the output (either
|
||||
`2**32` or `2**64`).
|
||||
|
||||
For full-range random integers, pass `minval=None` and `maxval=None` with an
|
||||
integer `dtype` (for integer dtypes, `minval` and `maxval` must be both
|
||||
`None` or both not `None`).
|
||||
|
||||
Args:
|
||||
shape: A 1-D integer Tensor or Python array. The shape of the output
|
||||
tensor.
|
||||
minval: A Tensor or Python value of type `dtype`, broadcastable with
|
||||
`shape` (for integer types, broadcasting is not supported, so it needs
|
||||
to be a scalar). The lower bound on the range of random values to
|
||||
generate. Defaults to 0.
|
||||
to be a scalar). The lower bound (included) on the range of random
|
||||
values to generate. Pass `None` for full-range integers. Defaults to 0.
|
||||
maxval: A Tensor or Python value of type `dtype`, broadcastable with
|
||||
`shape` (for integer types, broadcasting is not supported, so it needs
|
||||
to be a scalar). The upper bound on the range of random values to
|
||||
generate. Defaults to 1 if `dtype` is floating point.
|
||||
to be a scalar). The upper bound (excluded) on the range of random
|
||||
values to generate. Pass `None` for full-range integers. Defaults to 1
|
||||
if `dtype` is floating point.
|
||||
dtype: The type of the output.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
@ -705,13 +715,18 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
ValueError: If `dtype` is integral and `maxval` is not specified.
|
||||
"""
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
if maxval is None:
|
||||
if dtype.is_integer:
|
||||
raise ValueError("Must specify maxval for integer dtype %r" % dtype)
|
||||
if dtype.is_integer:
|
||||
if (minval is None) != (maxval is None):
|
||||
raise ValueError("For integer dtype {}, minval and maxval must be both "
|
||||
"`None` or both non-`None`; got minval={} and "
|
||||
"maxval={}".format(dtype, minval, maxval))
|
||||
elif maxval is None:
|
||||
maxval = 1
|
||||
with ops.name_scope(name, "stateful_uniform",
|
||||
[shape, minval, maxval]) as name:
|
||||
shape = _shape_tensor(shape)
|
||||
if dtype.is_integer and minval is None:
|
||||
return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
|
||||
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
|
||||
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
|
||||
if dtype.is_integer:
|
||||
@ -725,8 +740,8 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
|
||||
"""Uniform distribution on an integer type's entire range.
|
||||
|
||||
The other method `uniform` only covers the range [minval, maxval), which
|
||||
cannot be `dtype`'s full range because `maxval` is of type `dtype`.
|
||||
This method is the same as setting `minval` and `maxval` to `None` in the
|
||||
`uniform` method.
|
||||
|
||||
Args:
|
||||
shape: the shape of the output.
|
||||
@ -740,9 +755,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
with ops.name_scope(name, "stateful_uniform_full_int",
|
||||
[shape]) as name:
|
||||
shape = _shape_tensor(shape)
|
||||
return gen_stateful_random_ops.stateful_uniform_full_int(
|
||||
self.state.handle, self.algorithm, shape=shape,
|
||||
dtype=dtype, name=name)
|
||||
return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
|
||||
|
||||
def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
|
||||
"""Outputs random values from a binomial distribution.
|
||||
|
@ -748,6 +748,18 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(2, len(local_results))
|
||||
self.assertAllDifferent(local_results)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testUniformFullInt(self):
|
||||
"""Tests full-range int uniform.
|
||||
"""
|
||||
shape = [3, 4]
|
||||
dtype = dtypes.int32
|
||||
g = random.Generator.from_seed(1)
|
||||
r1 = g.uniform(shape=shape, dtype=dtype, minval=None)
|
||||
g = random.Generator.from_seed(1)
|
||||
r2 = g.uniform_full_int(shape=shape, dtype=dtype)
|
||||
self.assertAllEqual(r1, r2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -66,7 +66,7 @@ def stateless_random_uniform(shape,
|
||||
`maxval - minval` significantly smaller than the range of the output (either
|
||||
`2**32` or `2**64`).
|
||||
|
||||
For full full-range (i.e. inclusive of both max and min) random integers, pass
|
||||
For full-range (i.e. inclusive of both max and min) random integers, pass
|
||||
`minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
|
||||
either both `minval` and `maxval` must be `None` or neither may be `None`. For
|
||||
example:
|
||||
|
Loading…
Reference in New Issue
Block a user