Adds examples in stateful_random_ops.py, improves some docstrings, and also adds missing export for algorithm enum.
PiperOrigin-RevId: 285845389 Change-Id: Iaa60466151ba761a6f5a1ef6cd8b34cbbfdde5b1
This commit is contained in:
parent
1d61880a20
commit
832ba2b25a
@ -18,9 +18,10 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import sys
|
import enum # pylint: disable=g-bad-import-order
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import six
|
||||||
|
|
||||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -58,15 +59,21 @@ SEED_SIZE = 16 # in units of SEED_TYPE
|
|||||||
|
|
||||||
STATE_TYPE = SEED_TYPE
|
STATE_TYPE = SEED_TYPE
|
||||||
ALGORITHM_TYPE = STATE_TYPE
|
ALGORITHM_TYPE = STATE_TYPE
|
||||||
RNG_ALG_PHILOX = 1
|
|
||||||
RNG_ALG_THREEFRY = 2
|
|
||||||
DEFAULT_ALGORITHM = RNG_ALG_PHILOX
|
|
||||||
|
|
||||||
|
|
||||||
PHILOX_STATE_SIZE = 3
|
PHILOX_STATE_SIZE = 3
|
||||||
THREEFRY_STATE_SIZE = 2
|
THREEFRY_STATE_SIZE = 2
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("random.experimental.Algorithm")
|
||||||
|
class Algorithm(enum.Enum):
|
||||||
|
PHILOX = 1
|
||||||
|
THREEFRY = 2
|
||||||
|
|
||||||
|
|
||||||
|
RNG_ALG_PHILOX = Algorithm.PHILOX.value
|
||||||
|
RNG_ALG_THREEFRY = Algorithm.THREEFRY.value
|
||||||
|
DEFAULT_ALGORITHM = RNG_ALG_PHILOX
|
||||||
|
|
||||||
|
|
||||||
def non_deterministic_ints(shape, dtype=dtypes.int64):
|
def non_deterministic_ints(shape, dtype=dtypes.int64):
|
||||||
"""Non-deterministically generates some integers.
|
"""Non-deterministically generates some integers.
|
||||||
|
|
||||||
@ -100,8 +107,7 @@ def _make_1d_state(state_size, seed):
|
|||||||
Returns:
|
Returns:
|
||||||
a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
|
a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
|
||||||
"""
|
"""
|
||||||
int_types = (int,) if sys.version_info >= (3, 0) else (int, long)
|
if isinstance(seed, six.integer_types):
|
||||||
if isinstance(seed, int_types):
|
|
||||||
# chop the Python integer (infinite precision) into chunks of SEED_TYPE
|
# chop the Python integer (infinite precision) into chunks of SEED_TYPE
|
||||||
ls = []
|
ls = []
|
||||||
for _ in range(state_size):
|
for _ in range(state_size):
|
||||||
@ -149,18 +155,56 @@ def _make_state_from_seed(seed, alg):
|
|||||||
return _make_1d_state(_get_state_size(alg), seed)
|
return _make_1d_state(_get_state_size(alg), seed)
|
||||||
|
|
||||||
|
|
||||||
@tf_export("random.experimental.create_rng_state")
|
def _convert_alg_to_int(alg):
|
||||||
def create_rng_state(seed, algorithm):
|
"""Converts algorithm to an integer.
|
||||||
"""Creates a RNG state.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: an integer or 1-D tensor.
|
alg: can be one of these types: integer, Algorithm, Tensor, string. Allowed
|
||||||
algorithm: an integer representing the RNG algorithm.
|
strings are "philox" and "threefry".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a 1-D tensor whose size depends on the algorithm.
|
An integer, unless the input is a Tensor in which case a Tensor is returned.
|
||||||
"""
|
"""
|
||||||
return _make_state_from_seed(seed, algorithm)
|
if isinstance(alg, six.integer_types):
|
||||||
|
return alg
|
||||||
|
if isinstance(alg, Algorithm):
|
||||||
|
return alg.value
|
||||||
|
if isinstance(alg, ops.Tensor):
|
||||||
|
return alg
|
||||||
|
if isinstance(alg, str):
|
||||||
|
if alg == "philox":
|
||||||
|
return RNG_ALG_PHILOX
|
||||||
|
elif alg == "threefry":
|
||||||
|
return RNG_ALG_THREEFRY
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown algorithm name: %s" % alg)
|
||||||
|
else:
|
||||||
|
raise TypeError("Can't convert algorithm %s of type %s to int" %
|
||||||
|
(alg, type(alg)))
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("random.experimental.create_rng_state")
|
||||||
|
def create_rng_state(seed, alg):
|
||||||
|
"""Creates a RNG state from an integer or a vector.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> tf.random.experimental.create_rng_state(
|
||||||
|
... 1234, "philox")
|
||||||
|
array([1234, 0, 0])
|
||||||
|
>>> tf.random.experimental.create_rng_state(
|
||||||
|
... [12, 34], "threefry")
|
||||||
|
array([12, 34])
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: an integer or 1-D numpy array.
|
||||||
|
alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a 1-D numpy array whose size depends on the algorithm.
|
||||||
|
"""
|
||||||
|
alg = _convert_alg_to_int(alg)
|
||||||
|
return _make_state_from_seed(seed, alg)
|
||||||
|
|
||||||
|
|
||||||
def _shape_tensor(shape):
|
def _shape_tensor(shape):
|
||||||
@ -239,12 +283,55 @@ def _create_variable(*args, **kwargs):
|
|||||||
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||||
"""Random-number generator.
|
"""Random-number generator.
|
||||||
|
|
||||||
It uses Variable to manage its internal state, and allows choosing an
|
Example:
|
||||||
Random-Number-Generation (RNG) algorithm.
|
|
||||||
|
Creating a generator from a seed:
|
||||||
|
|
||||||
|
>>> g = tf.random.experimental.Generator.from_seed(1234)
|
||||||
|
>>> g.normal(shape=(2, 3))
|
||||||
|
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
|
||||||
|
array([[ 0.9356609 , 1.0854305 , -0.93788373],
|
||||||
|
[-0.5061547 , 1.3169702 , 0.7137579 ]], dtype=float32)>
|
||||||
|
|
||||||
|
Creating a generator from a non-deterministic state:
|
||||||
|
|
||||||
|
>>> g = tf.random.experimental.Generator.from_non_deterministic_state()
|
||||||
|
>>> g.normal(shape=(2, 3))
|
||||||
|
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
|
||||||
|
|
||||||
|
All the constructors allow explicitly choosing an Random-Number-Generation
|
||||||
|
(RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For
|
||||||
|
example:
|
||||||
|
|
||||||
|
>>> g = tf.random.experimental.Generator.from_seed(123, alg="philox")
|
||||||
|
>>> g.normal(shape=(2, 3))
|
||||||
|
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
|
||||||
|
array([[ 0.8673864 , -0.29899067, -0.9310337 ],
|
||||||
|
[-1.5828488 , 1.2481191 , -0.6770643 ]], dtype=float32)>
|
||||||
|
|
||||||
CPU, GPU and TPU with the same algorithm and seed will generate the same
|
CPU, GPU and TPU with the same algorithm and seed will generate the same
|
||||||
integer random numbers. Float-point results (such as the output of `normal`)
|
integer random numbers. Float-point results (such as the output of `normal`)
|
||||||
may have small numerical discrepancies between CPU and GPU.
|
may have small numerical discrepancies between different devices.
|
||||||
|
|
||||||
|
This class uses a `tf.Variable` to manage its internal state. Every time
|
||||||
|
random numbers are generated, the state of the generator will change. For
|
||||||
|
example:
|
||||||
|
|
||||||
|
>>> g = tf.random.experimental.Generator.from_seed(1234)
|
||||||
|
>>> g.state
|
||||||
|
<tf.Variable ... numpy=array([1234, 0, 0])>
|
||||||
|
>>> g.normal(shape=(2, 3))
|
||||||
|
<...>
|
||||||
|
>>> g.state
|
||||||
|
<tf.Variable ... numpy=array([2770, 0, 0])>
|
||||||
|
|
||||||
|
The shape of the state is algorithm-specific.
|
||||||
|
|
||||||
|
There is also a global generator:
|
||||||
|
|
||||||
|
>>> g = tf.random.experimental.get_global_generator()
|
||||||
|
>>> g.normal(shape=(2, 3))
|
||||||
|
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, copy_from=None, state=None, alg=None):
|
def __init__(self, copy_from=None, state=None, alg=None):
|
||||||
@ -263,11 +350,13 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
RNG, whose length and semantics are algorithm-specific. If it's a
|
RNG, whose length and semantics are algorithm-specific. If it's a
|
||||||
variable, the generator will reuse it instead of creating a new
|
variable, the generator will reuse it instead of creating a new
|
||||||
variable.
|
variable.
|
||||||
alg: the RNG algorithm. Possible values are `RNG_ALG_PHILOX` for the
|
alg: the RNG algorithm. Possible values are
|
||||||
Philox algorithm and `RNG_ALG_THREEFRY` for the ThreeFry
|
`tf.random.experimental.Algorithm.PHILOX` for the Philox algorithm and
|
||||||
algorithm (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
|
`tf.random.experimental.Algorithm.THREEFRY` for the ThreeFry algorithm
|
||||||
|
(see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
|
||||||
[https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
|
[https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
|
||||||
Note `RNG_ALG_PHILOX` guarantees the same numbers are produced (given
|
The string names `"philox"` and `"threefry"` can also be used.
|
||||||
|
Note `PHILOX` guarantees the same numbers are produced (given
|
||||||
the same random state) across all architextures (CPU, GPU, XLA etc).
|
the same random state) across all architextures (CPU, GPU, XLA etc).
|
||||||
|
|
||||||
Throws:
|
Throws:
|
||||||
@ -287,6 +376,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
assert alg is not None and state is not None
|
assert alg is not None and state is not None
|
||||||
|
alg = _convert_alg_to_int(alg)
|
||||||
if isinstance(state, variables.Variable):
|
if isinstance(state, variables.Variable):
|
||||||
_check_state_shape(state.shape, alg)
|
_check_state_shape(state.shape, alg)
|
||||||
self._state_var = state
|
self._state_var = state
|
||||||
@ -350,6 +440,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
if alg is None:
|
if alg is None:
|
||||||
# TODO(wangpeng): more sophisticated algorithm selection
|
# TODO(wangpeng): more sophisticated algorithm selection
|
||||||
alg = DEFAULT_ALGORITHM
|
alg = DEFAULT_ALGORITHM
|
||||||
|
alg = _convert_alg_to_int(alg)
|
||||||
state = create_rng_state(seed, alg)
|
state = create_rng_state(seed, alg)
|
||||||
return cls(state=state, alg=alg)
|
return cls(state=state, alg=alg)
|
||||||
|
|
||||||
@ -377,6 +468,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
if alg is None:
|
if alg is None:
|
||||||
# TODO(wangpeng): more sophisticated algorithm selection
|
# TODO(wangpeng): more sophisticated algorithm selection
|
||||||
alg = DEFAULT_ALGORITHM
|
alg = DEFAULT_ALGORITHM
|
||||||
|
alg = _convert_alg_to_int(alg)
|
||||||
state = non_deterministic_ints(shape=[_get_state_size(alg)],
|
state = non_deterministic_ints(shape=[_get_state_size(alg)],
|
||||||
dtype=SEED_TYPE)
|
dtype=SEED_TYPE)
|
||||||
return cls(state=state, alg=alg)
|
return cls(state=state, alg=alg)
|
||||||
@ -408,6 +500,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
"""
|
"""
|
||||||
counter = _convert_to_state_tensor(counter)
|
counter = _convert_to_state_tensor(counter)
|
||||||
key = _convert_to_state_tensor(key)
|
key = _convert_to_state_tensor(key)
|
||||||
|
alg = _convert_alg_to_int(alg)
|
||||||
counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
|
counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
|
||||||
key.shape.assert_is_compatible_with([])
|
key.shape.assert_is_compatible_with([])
|
||||||
key = array_ops.reshape(key, [1])
|
key = array_ops.reshape(key, [1])
|
||||||
@ -466,7 +559,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def algorithm(self):
|
def algorithm(self):
|
||||||
"""The RNG algorithm."""
|
"""The RNG algorithm id (a Python integer or scalar integer Tensor)."""
|
||||||
return self._alg
|
return self._alg
|
||||||
|
|
||||||
def _standard_normal(self, shape, dtype):
|
def _standard_normal(self, shape, dtype):
|
||||||
@ -806,6 +899,16 @@ global_generator = None
|
|||||||
|
|
||||||
@tf_export("random.experimental.get_global_generator")
|
@tf_export("random.experimental.get_global_generator")
|
||||||
def get_global_generator():
|
def get_global_generator():
|
||||||
|
"""Retrieves the global generator.
|
||||||
|
|
||||||
|
This function will create the global generator the first time it is called,
|
||||||
|
and the generator will be placed at the default device at that time, so one
|
||||||
|
needs to be careful when this function is first called. Using a generator
|
||||||
|
placed on a less-ideal device will incur performance regression.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The global `tf.random.experimental.Generator` object.
|
||||||
|
"""
|
||||||
global global_generator
|
global global_generator
|
||||||
if global_generator is None:
|
if global_generator is None:
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
|
@ -197,6 +197,17 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
|||||||
check_results(expected_normal1, f(constructor))
|
check_results(expected_normal1, f(constructor))
|
||||||
check_results(expected_normal2, f(constructor))
|
check_results(expected_normal2, f(constructor))
|
||||||
|
|
||||||
|
@parameterized.parameters([
|
||||||
|
("philox", random.RNG_ALG_PHILOX, random.Algorithm.PHILOX),
|
||||||
|
("threefry", random.RNG_ALG_THREEFRY, random.Algorithm.THREEFRY)])
|
||||||
|
@test_util.run_v2_only
|
||||||
|
def testAlg(self, name, int_id, enum_id):
|
||||||
|
g_by_name = random.Generator.from_seed(1234, name)
|
||||||
|
g_by_int = random.Generator.from_seed(1234, int_id)
|
||||||
|
g_by_enum = random.Generator.from_seed(1234, enum_id)
|
||||||
|
self.assertEqual(g_by_name.algorithm, g_by_int.algorithm)
|
||||||
|
self.assertEqual(g_by_name.algorithm, g_by_enum.algorithm)
|
||||||
|
|
||||||
@test_util.run_v2_only
|
@test_util.run_v2_only
|
||||||
def testGeneratorCreationWithVar(self):
|
def testGeneratorCreationWithVar(self):
|
||||||
"""Tests creating generator with a variable.
|
"""Tests creating generator with a variable.
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
path: "tensorflow.random.experimental.Algorithm"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<enum \'Algorithm\'>"
|
||||||
|
member {
|
||||||
|
name: "PHILOX"
|
||||||
|
mtype: "<enum \'Algorithm\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "THREEFRY"
|
||||||
|
mtype: "<enum \'Algorithm\'>"
|
||||||
|
}
|
||||||
|
}
|
@ -1,12 +1,16 @@
|
|||||||
path: "tensorflow.random.experimental"
|
path: "tensorflow.random.experimental"
|
||||||
tf_module {
|
tf_module {
|
||||||
|
member {
|
||||||
|
name: "Algorithm"
|
||||||
|
mtype: "<class \'enum.EnumMeta\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "Generator"
|
name: "Generator"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "create_rng_state"
|
name: "create_rng_state"
|
||||||
argspec: "args=[\'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'seed\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "get_global_generator"
|
name: "get_global_generator"
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
path: "tensorflow.random.experimental.Algorithm"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<enum \'Algorithm\'>"
|
||||||
|
member {
|
||||||
|
name: "PHILOX"
|
||||||
|
mtype: "<enum \'Algorithm\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "THREEFRY"
|
||||||
|
mtype: "<enum \'Algorithm\'>"
|
||||||
|
}
|
||||||
|
}
|
@ -1,12 +1,16 @@
|
|||||||
path: "tensorflow.random.experimental"
|
path: "tensorflow.random.experimental"
|
||||||
tf_module {
|
tf_module {
|
||||||
|
member {
|
||||||
|
name: "Algorithm"
|
||||||
|
mtype: "<class \'enum.EnumMeta\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "Generator"
|
name: "Generator"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "create_rng_state"
|
name: "create_rng_state"
|
||||||
argspec: "args=[\'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'seed\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "get_global_generator"
|
name: "get_global_generator"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user