From 3e95e9f03aea5592eef2bcfa6eb8ec94cbb73db3 Mon Sep 17 00:00:00 2001 From: Peng Wang Date: Fri, 6 Mar 2020 16:01:57 -0800 Subject: [PATCH] Call uniform_full_int in tf.random.Generator.uniform when dtype is integer and both minval and maxval are None, to align with stateless_random_uniform. PiperOrigin-RevId: 299455876 Change-Id: I7274a15bef63ea349631e2f3053c71af387c2ca1 --- tensorflow/python/ops/stateful_random_ops.py | 37 +++++++++++++------ .../python/ops/stateful_random_ops_test.py | 12 ++++++ tensorflow/python/ops/stateless_random_ops.py | 2 +- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/tensorflow/python/ops/stateful_random_ops.py b/tensorflow/python/ops/stateful_random_ops.py index 33876c184ee..9e8eba2e789 100644 --- a/tensorflow/python/ops/stateful_random_ops.py +++ b/tensorflow/python/ops/stateful_random_ops.py @@ -666,6 +666,11 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor): return gen_stateful_random_ops.stateful_uniform( self.state.handle, self.algorithm, shape=shape, dtype=dtype) + def _uniform_full_int(self, shape, dtype, name=None): + return gen_stateful_random_ops.stateful_uniform_full_int( + self.state.handle, self.algorithm, shape=shape, + dtype=dtype, name=name) + def uniform(self, shape, minval=0, maxval=None, dtype=dtypes.float32, name=None): """Outputs random values from a uniform distribution. @@ -684,17 +689,22 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor): `maxval - minval` significantly smaller than the range of the output (either `2**32` or `2**64`). + For full-range random integers, pass `minval=None` and `maxval=None` with an + integer `dtype` (for integer dtypes, `minval` and `maxval` must be both + `None` or both not `None`). + Args: shape: A 1-D integer Tensor or Python array. The shape of the output tensor. minval: A Tensor or Python value of type `dtype`, broadcastable with `shape` (for integer types, broadcasting is not supported, so it needs - to be a scalar). The lower bound on the range of random values to - generate. Defaults to 0. + to be a scalar). The lower bound (included) on the range of random + values to generate. Pass `None` for full-range integers. Defaults to 0. maxval: A Tensor or Python value of type `dtype`, broadcastable with `shape` (for integer types, broadcasting is not supported, so it needs - to be a scalar). The upper bound on the range of random values to - generate. Defaults to 1 if `dtype` is floating point. + to be a scalar). The upper bound (excluded) on the range of random + values to generate. Pass `None` for full-range integers. Defaults to 1 + if `dtype` is floating point. dtype: The type of the output. name: A name for the operation (optional). @@ -705,13 +715,18 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor): ValueError: If `dtype` is integral and `maxval` is not specified. """ dtype = dtypes.as_dtype(dtype) - if maxval is None: - if dtype.is_integer: - raise ValueError("Must specify maxval for integer dtype %r" % dtype) + if dtype.is_integer: + if (minval is None) != (maxval is None): + raise ValueError("For integer dtype {}, minval and maxval must be both " + "`None` or both non-`None`; got minval={} and " + "maxval={}".format(dtype, minval, maxval)) + elif maxval is None: maxval = 1 with ops.name_scope(name, "stateful_uniform", [shape, minval, maxval]) as name: shape = _shape_tensor(shape) + if dtype.is_integer and minval is None: + return self._uniform_full_int(shape=shape, dtype=dtype, name=name) minval = ops.convert_to_tensor(minval, dtype=dtype, name="min") maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max") if dtype.is_integer: @@ -725,8 +740,8 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor): def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None): """Uniform distribution on an integer type's entire range. - The other method `uniform` only covers the range [minval, maxval), which - cannot be `dtype`'s full range because `maxval` is of type `dtype`. + This method is the same as setting `minval` and `maxval` to `None` in the + `uniform` method. Args: shape: the shape of the output. @@ -740,9 +755,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor): with ops.name_scope(name, "stateful_uniform_full_int", [shape]) as name: shape = _shape_tensor(shape) - return gen_stateful_random_ops.stateful_uniform_full_int( - self.state.handle, self.algorithm, shape=shape, - dtype=dtype, name=name) + return self._uniform_full_int(shape=shape, dtype=dtype, name=name) def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None): """Outputs random values from a binomial distribution. diff --git a/tensorflow/python/ops/stateful_random_ops_test.py b/tensorflow/python/ops/stateful_random_ops_test.py index be91a4ca479..2389a068854 100644 --- a/tensorflow/python/ops/stateful_random_ops_test.py +++ b/tensorflow/python/ops/stateful_random_ops_test.py @@ -748,6 +748,18 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(2, len(local_results)) self.assertAllDifferent(local_results) + @test_util.run_v2_only + def testUniformFullInt(self): + """Tests full-range int uniform. + """ + shape = [3, 4] + dtype = dtypes.int32 + g = random.Generator.from_seed(1) + r1 = g.uniform(shape=shape, dtype=dtype, minval=None) + g = random.Generator.from_seed(1) + r2 = g.uniform_full_int(shape=shape, dtype=dtype) + self.assertAllEqual(r1, r2) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/stateless_random_ops.py b/tensorflow/python/ops/stateless_random_ops.py index 316fcf1928e..05386f0a5b9 100644 --- a/tensorflow/python/ops/stateless_random_ops.py +++ b/tensorflow/python/ops/stateless_random_ops.py @@ -66,7 +66,7 @@ def stateless_random_uniform(shape, `maxval - minval` significantly smaller than the range of the output (either `2**32` or `2**64`). - For full full-range (i.e. inclusive of both max and min) random integers, pass + For full-range (i.e. inclusive of both max and min) random integers, pass `minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype either both `minval` and `maxval` must be `None` or neither may be `None`. For example: