Adds examples in stateful_random_ops.py, improves some docstrings, and also adds missing export for algorithm enum.
PiperOrigin-RevId: 285845389 Change-Id: Iaa60466151ba761a6f5a1ef6cd8b34cbbfdde5b1
This commit is contained in:
parent
1d61880a20
commit
832ba2b25a
@ -18,9 +18,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import enum # pylint: disable=g-bad-import-order
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.eager import context
|
||||
@ -58,15 +59,21 @@ SEED_SIZE = 16 # in units of SEED_TYPE
|
||||
|
||||
STATE_TYPE = SEED_TYPE
|
||||
ALGORITHM_TYPE = STATE_TYPE
|
||||
RNG_ALG_PHILOX = 1
|
||||
RNG_ALG_THREEFRY = 2
|
||||
DEFAULT_ALGORITHM = RNG_ALG_PHILOX
|
||||
|
||||
|
||||
PHILOX_STATE_SIZE = 3
|
||||
THREEFRY_STATE_SIZE = 2
|
||||
|
||||
|
||||
@tf_export("random.experimental.Algorithm")
|
||||
class Algorithm(enum.Enum):
|
||||
PHILOX = 1
|
||||
THREEFRY = 2
|
||||
|
||||
|
||||
RNG_ALG_PHILOX = Algorithm.PHILOX.value
|
||||
RNG_ALG_THREEFRY = Algorithm.THREEFRY.value
|
||||
DEFAULT_ALGORITHM = RNG_ALG_PHILOX
|
||||
|
||||
|
||||
def non_deterministic_ints(shape, dtype=dtypes.int64):
|
||||
"""Non-deterministically generates some integers.
|
||||
|
||||
@ -100,8 +107,7 @@ def _make_1d_state(state_size, seed):
|
||||
Returns:
|
||||
a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
|
||||
"""
|
||||
int_types = (int,) if sys.version_info >= (3, 0) else (int, long)
|
||||
if isinstance(seed, int_types):
|
||||
if isinstance(seed, six.integer_types):
|
||||
# chop the Python integer (infinite precision) into chunks of SEED_TYPE
|
||||
ls = []
|
||||
for _ in range(state_size):
|
||||
@ -149,18 +155,56 @@ def _make_state_from_seed(seed, alg):
|
||||
return _make_1d_state(_get_state_size(alg), seed)
|
||||
|
||||
|
||||
@tf_export("random.experimental.create_rng_state")
|
||||
def create_rng_state(seed, algorithm):
|
||||
"""Creates a RNG state.
|
||||
def _convert_alg_to_int(alg):
|
||||
"""Converts algorithm to an integer.
|
||||
|
||||
Args:
|
||||
seed: an integer or 1-D tensor.
|
||||
algorithm: an integer representing the RNG algorithm.
|
||||
alg: can be one of these types: integer, Algorithm, Tensor, string. Allowed
|
||||
strings are "philox" and "threefry".
|
||||
|
||||
Returns:
|
||||
a 1-D tensor whose size depends on the algorithm.
|
||||
An integer, unless the input is a Tensor in which case a Tensor is returned.
|
||||
"""
|
||||
return _make_state_from_seed(seed, algorithm)
|
||||
if isinstance(alg, six.integer_types):
|
||||
return alg
|
||||
if isinstance(alg, Algorithm):
|
||||
return alg.value
|
||||
if isinstance(alg, ops.Tensor):
|
||||
return alg
|
||||
if isinstance(alg, str):
|
||||
if alg == "philox":
|
||||
return RNG_ALG_PHILOX
|
||||
elif alg == "threefry":
|
||||
return RNG_ALG_THREEFRY
|
||||
else:
|
||||
raise ValueError("Unknown algorithm name: %s" % alg)
|
||||
else:
|
||||
raise TypeError("Can't convert algorithm %s of type %s to int" %
|
||||
(alg, type(alg)))
|
||||
|
||||
|
||||
@tf_export("random.experimental.create_rng_state")
|
||||
def create_rng_state(seed, alg):
|
||||
"""Creates a RNG state from an integer or a vector.
|
||||
|
||||
Example:
|
||||
|
||||
>>> tf.random.experimental.create_rng_state(
|
||||
... 1234, "philox")
|
||||
array([1234, 0, 0])
|
||||
>>> tf.random.experimental.create_rng_state(
|
||||
... [12, 34], "threefry")
|
||||
array([12, 34])
|
||||
|
||||
Args:
|
||||
seed: an integer or 1-D numpy array.
|
||||
alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.
|
||||
|
||||
Returns:
|
||||
a 1-D numpy array whose size depends on the algorithm.
|
||||
"""
|
||||
alg = _convert_alg_to_int(alg)
|
||||
return _make_state_from_seed(seed, alg)
|
||||
|
||||
|
||||
def _shape_tensor(shape):
|
||||
@ -239,12 +283,55 @@ def _create_variable(*args, **kwargs):
|
||||
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
"""Random-number generator.
|
||||
|
||||
It uses Variable to manage its internal state, and allows choosing an
|
||||
Random-Number-Generation (RNG) algorithm.
|
||||
Example:
|
||||
|
||||
Creating a generator from a seed:
|
||||
|
||||
>>> g = tf.random.experimental.Generator.from_seed(1234)
|
||||
>>> g.normal(shape=(2, 3))
|
||||
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
|
||||
array([[ 0.9356609 , 1.0854305 , -0.93788373],
|
||||
[-0.5061547 , 1.3169702 , 0.7137579 ]], dtype=float32)>
|
||||
|
||||
Creating a generator from a non-deterministic state:
|
||||
|
||||
>>> g = tf.random.experimental.Generator.from_non_deterministic_state()
|
||||
>>> g.normal(shape=(2, 3))
|
||||
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
|
||||
|
||||
All the constructors allow explicitly choosing an Random-Number-Generation
|
||||
(RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For
|
||||
example:
|
||||
|
||||
>>> g = tf.random.experimental.Generator.from_seed(123, alg="philox")
|
||||
>>> g.normal(shape=(2, 3))
|
||||
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
|
||||
array([[ 0.8673864 , -0.29899067, -0.9310337 ],
|
||||
[-1.5828488 , 1.2481191 , -0.6770643 ]], dtype=float32)>
|
||||
|
||||
CPU, GPU and TPU with the same algorithm and seed will generate the same
|
||||
integer random numbers. Float-point results (such as the output of `normal`)
|
||||
may have small numerical discrepancies between CPU and GPU.
|
||||
may have small numerical discrepancies between different devices.
|
||||
|
||||
This class uses a `tf.Variable` to manage its internal state. Every time
|
||||
random numbers are generated, the state of the generator will change. For
|
||||
example:
|
||||
|
||||
>>> g = tf.random.experimental.Generator.from_seed(1234)
|
||||
>>> g.state
|
||||
<tf.Variable ... numpy=array([1234, 0, 0])>
|
||||
>>> g.normal(shape=(2, 3))
|
||||
<...>
|
||||
>>> g.state
|
||||
<tf.Variable ... numpy=array([2770, 0, 0])>
|
||||
|
||||
The shape of the state is algorithm-specific.
|
||||
|
||||
There is also a global generator:
|
||||
|
||||
>>> g = tf.random.experimental.get_global_generator()
|
||||
>>> g.normal(shape=(2, 3))
|
||||
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
|
||||
"""
|
||||
|
||||
def __init__(self, copy_from=None, state=None, alg=None):
|
||||
@ -263,11 +350,13 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
RNG, whose length and semantics are algorithm-specific. If it's a
|
||||
variable, the generator will reuse it instead of creating a new
|
||||
variable.
|
||||
alg: the RNG algorithm. Possible values are `RNG_ALG_PHILOX` for the
|
||||
Philox algorithm and `RNG_ALG_THREEFRY` for the ThreeFry
|
||||
algorithm (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
|
||||
alg: the RNG algorithm. Possible values are
|
||||
`tf.random.experimental.Algorithm.PHILOX` for the Philox algorithm and
|
||||
`tf.random.experimental.Algorithm.THREEFRY` for the ThreeFry algorithm
|
||||
(see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
|
||||
[https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
|
||||
Note `RNG_ALG_PHILOX` guarantees the same numbers are produced (given
|
||||
The string names `"philox"` and `"threefry"` can also be used.
|
||||
Note `PHILOX` guarantees the same numbers are produced (given
|
||||
the same random state) across all architextures (CPU, GPU, XLA etc).
|
||||
|
||||
Throws:
|
||||
@ -287,6 +376,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
|
||||
else:
|
||||
assert alg is not None and state is not None
|
||||
alg = _convert_alg_to_int(alg)
|
||||
if isinstance(state, variables.Variable):
|
||||
_check_state_shape(state.shape, alg)
|
||||
self._state_var = state
|
||||
@ -350,6 +440,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
if alg is None:
|
||||
# TODO(wangpeng): more sophisticated algorithm selection
|
||||
alg = DEFAULT_ALGORITHM
|
||||
alg = _convert_alg_to_int(alg)
|
||||
state = create_rng_state(seed, alg)
|
||||
return cls(state=state, alg=alg)
|
||||
|
||||
@ -377,6 +468,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
if alg is None:
|
||||
# TODO(wangpeng): more sophisticated algorithm selection
|
||||
alg = DEFAULT_ALGORITHM
|
||||
alg = _convert_alg_to_int(alg)
|
||||
state = non_deterministic_ints(shape=[_get_state_size(alg)],
|
||||
dtype=SEED_TYPE)
|
||||
return cls(state=state, alg=alg)
|
||||
@ -408,6 +500,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
"""
|
||||
counter = _convert_to_state_tensor(counter)
|
||||
key = _convert_to_state_tensor(key)
|
||||
alg = _convert_alg_to_int(alg)
|
||||
counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
|
||||
key.shape.assert_is_compatible_with([])
|
||||
key = array_ops.reshape(key, [1])
|
||||
@ -466,7 +559,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
|
||||
@property
|
||||
def algorithm(self):
|
||||
"""The RNG algorithm."""
|
||||
"""The RNG algorithm id (a Python integer or scalar integer Tensor)."""
|
||||
return self._alg
|
||||
|
||||
def _standard_normal(self, shape, dtype):
|
||||
@ -806,6 +899,16 @@ global_generator = None
|
||||
|
||||
@tf_export("random.experimental.get_global_generator")
|
||||
def get_global_generator():
|
||||
"""Retrieves the global generator.
|
||||
|
||||
This function will create the global generator the first time it is called,
|
||||
and the generator will be placed at the default device at that time, so one
|
||||
needs to be careful when this function is first called. Using a generator
|
||||
placed on a less-ideal device will incur performance regression.
|
||||
|
||||
Returns:
|
||||
The global `tf.random.experimental.Generator` object.
|
||||
"""
|
||||
global global_generator
|
||||
if global_generator is None:
|
||||
with ops.init_scope():
|
||||
|
@ -197,6 +197,17 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
check_results(expected_normal1, f(constructor))
|
||||
check_results(expected_normal2, f(constructor))
|
||||
|
||||
@parameterized.parameters([
|
||||
("philox", random.RNG_ALG_PHILOX, random.Algorithm.PHILOX),
|
||||
("threefry", random.RNG_ALG_THREEFRY, random.Algorithm.THREEFRY)])
|
||||
@test_util.run_v2_only
|
||||
def testAlg(self, name, int_id, enum_id):
|
||||
g_by_name = random.Generator.from_seed(1234, name)
|
||||
g_by_int = random.Generator.from_seed(1234, int_id)
|
||||
g_by_enum = random.Generator.from_seed(1234, enum_id)
|
||||
self.assertEqual(g_by_name.algorithm, g_by_int.algorithm)
|
||||
self.assertEqual(g_by_name.algorithm, g_by_enum.algorithm)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testGeneratorCreationWithVar(self):
|
||||
"""Tests creating generator with a variable.
|
||||
|
@ -0,0 +1,12 @@
|
||||
path: "tensorflow.random.experimental.Algorithm"
|
||||
tf_class {
|
||||
is_instance: "<enum \'Algorithm\'>"
|
||||
member {
|
||||
name: "PHILOX"
|
||||
mtype: "<enum \'Algorithm\'>"
|
||||
}
|
||||
member {
|
||||
name: "THREEFRY"
|
||||
mtype: "<enum \'Algorithm\'>"
|
||||
}
|
||||
}
|
@ -1,12 +1,16 @@
|
||||
path: "tensorflow.random.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "Algorithm"
|
||||
mtype: "<class \'enum.EnumMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "Generator"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "create_rng_state"
|
||||
argspec: "args=[\'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'seed\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_global_generator"
|
||||
|
@ -0,0 +1,12 @@
|
||||
path: "tensorflow.random.experimental.Algorithm"
|
||||
tf_class {
|
||||
is_instance: "<enum \'Algorithm\'>"
|
||||
member {
|
||||
name: "PHILOX"
|
||||
mtype: "<enum \'Algorithm\'>"
|
||||
}
|
||||
member {
|
||||
name: "THREEFRY"
|
||||
mtype: "<enum \'Algorithm\'>"
|
||||
}
|
||||
}
|
@ -1,12 +1,16 @@
|
||||
path: "tensorflow.random.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "Algorithm"
|
||||
mtype: "<class \'enum.EnumMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "Generator"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "create_rng_state"
|
||||
argspec: "args=[\'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'seed\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_global_generator"
|
||||
|
Loading…
Reference in New Issue
Block a user