Call uniform_full_int in tf.random.Generator.uniform when dtype is integer and both minval and maxval are None, to align with stateless_random_uniform.
PiperOrigin-RevId: 299455876 Change-Id: I7274a15bef63ea349631e2f3053c71af387c2ca1
This commit is contained in:
parent
1e2122cf33
commit
3e95e9f03a
@ -666,6 +666,11 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
return gen_stateful_random_ops.stateful_uniform(
|
return gen_stateful_random_ops.stateful_uniform(
|
||||||
self.state.handle, self.algorithm, shape=shape, dtype=dtype)
|
self.state.handle, self.algorithm, shape=shape, dtype=dtype)
|
||||||
|
|
||||||
|
def _uniform_full_int(self, shape, dtype, name=None):
|
||||||
|
return gen_stateful_random_ops.stateful_uniform_full_int(
|
||||||
|
self.state.handle, self.algorithm, shape=shape,
|
||||||
|
dtype=dtype, name=name)
|
||||||
|
|
||||||
def uniform(self, shape, minval=0, maxval=None,
|
def uniform(self, shape, minval=0, maxval=None,
|
||||||
dtype=dtypes.float32, name=None):
|
dtype=dtypes.float32, name=None):
|
||||||
"""Outputs random values from a uniform distribution.
|
"""Outputs random values from a uniform distribution.
|
||||||
@ -684,17 +689,22 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
`maxval - minval` significantly smaller than the range of the output (either
|
`maxval - minval` significantly smaller than the range of the output (either
|
||||||
`2**32` or `2**64`).
|
`2**32` or `2**64`).
|
||||||
|
|
||||||
|
For full-range random integers, pass `minval=None` and `maxval=None` with an
|
||||||
|
integer `dtype` (for integer dtypes, `minval` and `maxval` must be both
|
||||||
|
`None` or both not `None`).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shape: A 1-D integer Tensor or Python array. The shape of the output
|
shape: A 1-D integer Tensor or Python array. The shape of the output
|
||||||
tensor.
|
tensor.
|
||||||
minval: A Tensor or Python value of type `dtype`, broadcastable with
|
minval: A Tensor or Python value of type `dtype`, broadcastable with
|
||||||
`shape` (for integer types, broadcasting is not supported, so it needs
|
`shape` (for integer types, broadcasting is not supported, so it needs
|
||||||
to be a scalar). The lower bound on the range of random values to
|
to be a scalar). The lower bound (included) on the range of random
|
||||||
generate. Defaults to 0.
|
values to generate. Pass `None` for full-range integers. Defaults to 0.
|
||||||
maxval: A Tensor or Python value of type `dtype`, broadcastable with
|
maxval: A Tensor or Python value of type `dtype`, broadcastable with
|
||||||
`shape` (for integer types, broadcasting is not supported, so it needs
|
`shape` (for integer types, broadcasting is not supported, so it needs
|
||||||
to be a scalar). The upper bound on the range of random values to
|
to be a scalar). The upper bound (excluded) on the range of random
|
||||||
generate. Defaults to 1 if `dtype` is floating point.
|
values to generate. Pass `None` for full-range integers. Defaults to 1
|
||||||
|
if `dtype` is floating point.
|
||||||
dtype: The type of the output.
|
dtype: The type of the output.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
@ -705,13 +715,18 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
ValueError: If `dtype` is integral and `maxval` is not specified.
|
ValueError: If `dtype` is integral and `maxval` is not specified.
|
||||||
"""
|
"""
|
||||||
dtype = dtypes.as_dtype(dtype)
|
dtype = dtypes.as_dtype(dtype)
|
||||||
if maxval is None:
|
if dtype.is_integer:
|
||||||
if dtype.is_integer:
|
if (minval is None) != (maxval is None):
|
||||||
raise ValueError("Must specify maxval for integer dtype %r" % dtype)
|
raise ValueError("For integer dtype {}, minval and maxval must be both "
|
||||||
|
"`None` or both non-`None`; got minval={} and "
|
||||||
|
"maxval={}".format(dtype, minval, maxval))
|
||||||
|
elif maxval is None:
|
||||||
maxval = 1
|
maxval = 1
|
||||||
with ops.name_scope(name, "stateful_uniform",
|
with ops.name_scope(name, "stateful_uniform",
|
||||||
[shape, minval, maxval]) as name:
|
[shape, minval, maxval]) as name:
|
||||||
shape = _shape_tensor(shape)
|
shape = _shape_tensor(shape)
|
||||||
|
if dtype.is_integer and minval is None:
|
||||||
|
return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
|
||||||
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
|
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
|
||||||
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
|
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
|
||||||
if dtype.is_integer:
|
if dtype.is_integer:
|
||||||
@ -725,8 +740,8 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
|
def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
|
||||||
"""Uniform distribution on an integer type's entire range.
|
"""Uniform distribution on an integer type's entire range.
|
||||||
|
|
||||||
The other method `uniform` only covers the range [minval, maxval), which
|
This method is the same as setting `minval` and `maxval` to `None` in the
|
||||||
cannot be `dtype`'s full range because `maxval` is of type `dtype`.
|
`uniform` method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shape: the shape of the output.
|
shape: the shape of the output.
|
||||||
@ -740,9 +755,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
with ops.name_scope(name, "stateful_uniform_full_int",
|
with ops.name_scope(name, "stateful_uniform_full_int",
|
||||||
[shape]) as name:
|
[shape]) as name:
|
||||||
shape = _shape_tensor(shape)
|
shape = _shape_tensor(shape)
|
||||||
return gen_stateful_random_ops.stateful_uniform_full_int(
|
return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
|
||||||
self.state.handle, self.algorithm, shape=shape,
|
|
||||||
dtype=dtype, name=name)
|
|
||||||
|
|
||||||
def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
|
def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
|
||||||
"""Outputs random values from a binomial distribution.
|
"""Outputs random values from a binomial distribution.
|
||||||
|
@ -748,6 +748,18 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(2, len(local_results))
|
self.assertAllEqual(2, len(local_results))
|
||||||
self.assertAllDifferent(local_results)
|
self.assertAllDifferent(local_results)
|
||||||
|
|
||||||
|
@test_util.run_v2_only
|
||||||
|
def testUniformFullInt(self):
|
||||||
|
"""Tests full-range int uniform.
|
||||||
|
"""
|
||||||
|
shape = [3, 4]
|
||||||
|
dtype = dtypes.int32
|
||||||
|
g = random.Generator.from_seed(1)
|
||||||
|
r1 = g.uniform(shape=shape, dtype=dtype, minval=None)
|
||||||
|
g = random.Generator.from_seed(1)
|
||||||
|
r2 = g.uniform_full_int(shape=shape, dtype=dtype)
|
||||||
|
self.assertAllEqual(r1, r2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -66,7 +66,7 @@ def stateless_random_uniform(shape,
|
|||||||
`maxval - minval` significantly smaller than the range of the output (either
|
`maxval - minval` significantly smaller than the range of the output (either
|
||||||
`2**32` or `2**64`).
|
`2**32` or `2**64`).
|
||||||
|
|
||||||
For full full-range (i.e. inclusive of both max and min) random integers, pass
|
For full-range (i.e. inclusive of both max and min) random integers, pass
|
||||||
`minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
|
`minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
|
||||||
either both `minval` and `maxval` must be `None` or neither may be `None`. For
|
either both `minval` and `maxval` must be `None` or neither may be `None`. For
|
||||||
example:
|
example:
|
||||||
|
Loading…
Reference in New Issue
Block a user