Adds examples in stateful_random_ops.py, improves some docstrings, and also adds missing export for algorithm enum.

PiperOrigin-RevId: 285845389
Change-Id: Iaa60466151ba761a6f5a1ef6cd8b34cbbfdde5b1
This commit is contained in:
Peng Wang 2019-12-16 14:03:37 -08:00 committed by TensorFlower Gardener
parent 1d61880a20
commit 832ba2b25a
6 changed files with 171 additions and 25 deletions

View File

@ -18,9 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import enum # pylint: disable=g-bad-import-order
import numpy as np
import six
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.eager import context
@ -58,15 +59,21 @@ SEED_SIZE = 16 # in units of SEED_TYPE
STATE_TYPE = SEED_TYPE
ALGORITHM_TYPE = STATE_TYPE
RNG_ALG_PHILOX = 1
RNG_ALG_THREEFRY = 2
DEFAULT_ALGORITHM = RNG_ALG_PHILOX
PHILOX_STATE_SIZE = 3
THREEFRY_STATE_SIZE = 2
@tf_export("random.experimental.Algorithm")
class Algorithm(enum.Enum):
PHILOX = 1
THREEFRY = 2
RNG_ALG_PHILOX = Algorithm.PHILOX.value
RNG_ALG_THREEFRY = Algorithm.THREEFRY.value
DEFAULT_ALGORITHM = RNG_ALG_PHILOX
def non_deterministic_ints(shape, dtype=dtypes.int64):
"""Non-deterministically generates some integers.
@ -100,8 +107,7 @@ def _make_1d_state(state_size, seed):
Returns:
a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
"""
int_types = (int,) if sys.version_info >= (3, 0) else (int, long)
if isinstance(seed, int_types):
if isinstance(seed, six.integer_types):
# chop the Python integer (infinite precision) into chunks of SEED_TYPE
ls = []
for _ in range(state_size):
@ -149,18 +155,56 @@ def _make_state_from_seed(seed, alg):
return _make_1d_state(_get_state_size(alg), seed)
@tf_export("random.experimental.create_rng_state")
def create_rng_state(seed, algorithm):
"""Creates a RNG state.
def _convert_alg_to_int(alg):
"""Converts algorithm to an integer.
Args:
seed: an integer or 1-D tensor.
algorithm: an integer representing the RNG algorithm.
alg: can be one of these types: integer, Algorithm, Tensor, string. Allowed
strings are "philox" and "threefry".
Returns:
a 1-D tensor whose size depends on the algorithm.
An integer, unless the input is a Tensor in which case a Tensor is returned.
"""
return _make_state_from_seed(seed, algorithm)
if isinstance(alg, six.integer_types):
return alg
if isinstance(alg, Algorithm):
return alg.value
if isinstance(alg, ops.Tensor):
return alg
if isinstance(alg, str):
if alg == "philox":
return RNG_ALG_PHILOX
elif alg == "threefry":
return RNG_ALG_THREEFRY
else:
raise ValueError("Unknown algorithm name: %s" % alg)
else:
raise TypeError("Can't convert algorithm %s of type %s to int" %
(alg, type(alg)))
@tf_export("random.experimental.create_rng_state")
def create_rng_state(seed, alg):
"""Creates a RNG state from an integer or a vector.
Example:
>>> tf.random.experimental.create_rng_state(
... 1234, "philox")
array([1234, 0, 0])
>>> tf.random.experimental.create_rng_state(
... [12, 34], "threefry")
array([12, 34])
Args:
seed: an integer or 1-D numpy array.
alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.
Returns:
a 1-D numpy array whose size depends on the algorithm.
"""
alg = _convert_alg_to_int(alg)
return _make_state_from_seed(seed, alg)
def _shape_tensor(shape):
@ -239,12 +283,55 @@ def _create_variable(*args, **kwargs):
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
"""Random-number generator.
It uses Variable to manage its internal state, and allows choosing an
Random-Number-Generation (RNG) algorithm.
Example:
Creating a generator from a seed:
>>> g = tf.random.experimental.Generator.from_seed(1234)
>>> g.normal(shape=(2, 3))
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.9356609 , 1.0854305 , -0.93788373],
[-0.5061547 , 1.3169702 , 0.7137579 ]], dtype=float32)>
Creating a generator from a non-deterministic state:
>>> g = tf.random.experimental.Generator.from_non_deterministic_state()
>>> g.normal(shape=(2, 3))
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
All the constructors allow explicitly choosing an Random-Number-Generation
(RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For
example:
>>> g = tf.random.experimental.Generator.from_seed(123, alg="philox")
>>> g.normal(shape=(2, 3))
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 0.8673864 , -0.29899067, -0.9310337 ],
[-1.5828488 , 1.2481191 , -0.6770643 ]], dtype=float32)>
CPU, GPU and TPU with the same algorithm and seed will generate the same
integer random numbers. Float-point results (such as the output of `normal`)
may have small numerical discrepancies between CPU and GPU.
may have small numerical discrepancies between different devices.
This class uses a `tf.Variable` to manage its internal state. Every time
random numbers are generated, the state of the generator will change. For
example:
>>> g = tf.random.experimental.Generator.from_seed(1234)
>>> g.state
<tf.Variable ... numpy=array([1234, 0, 0])>
>>> g.normal(shape=(2, 3))
<...>
>>> g.state
<tf.Variable ... numpy=array([2770, 0, 0])>
The shape of the state is algorithm-specific.
There is also a global generator:
>>> g = tf.random.experimental.get_global_generator()
>>> g.normal(shape=(2, 3))
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
"""
def __init__(self, copy_from=None, state=None, alg=None):
@ -263,11 +350,13 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
RNG, whose length and semantics are algorithm-specific. If it's a
variable, the generator will reuse it instead of creating a new
variable.
alg: the RNG algorithm. Possible values are `RNG_ALG_PHILOX` for the
Philox algorithm and `RNG_ALG_THREEFRY` for the ThreeFry
algorithm (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
alg: the RNG algorithm. Possible values are
`tf.random.experimental.Algorithm.PHILOX` for the Philox algorithm and
`tf.random.experimental.Algorithm.THREEFRY` for the ThreeFry algorithm
(see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
[https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
Note `RNG_ALG_PHILOX` guarantees the same numbers are produced (given
The string names `"philox"` and `"threefry"` can also be used.
Note `PHILOX` guarantees the same numbers are produced (given
the same random state) across all architextures (CPU, GPU, XLA etc).
Throws:
@ -287,6 +376,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
else:
assert alg is not None and state is not None
alg = _convert_alg_to_int(alg)
if isinstance(state, variables.Variable):
_check_state_shape(state.shape, alg)
self._state_var = state
@ -350,6 +440,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
if alg is None:
# TODO(wangpeng): more sophisticated algorithm selection
alg = DEFAULT_ALGORITHM
alg = _convert_alg_to_int(alg)
state = create_rng_state(seed, alg)
return cls(state=state, alg=alg)
@ -377,6 +468,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
if alg is None:
# TODO(wangpeng): more sophisticated algorithm selection
alg = DEFAULT_ALGORITHM
alg = _convert_alg_to_int(alg)
state = non_deterministic_ints(shape=[_get_state_size(alg)],
dtype=SEED_TYPE)
return cls(state=state, alg=alg)
@ -408,6 +500,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
"""
counter = _convert_to_state_tensor(counter)
key = _convert_to_state_tensor(key)
alg = _convert_alg_to_int(alg)
counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
key.shape.assert_is_compatible_with([])
key = array_ops.reshape(key, [1])
@ -466,7 +559,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
@property
def algorithm(self):
"""The RNG algorithm."""
"""The RNG algorithm id (a Python integer or scalar integer Tensor)."""
return self._alg
def _standard_normal(self, shape, dtype):
@ -806,6 +899,16 @@ global_generator = None
@tf_export("random.experimental.get_global_generator")
def get_global_generator():
"""Retrieves the global generator.
This function will create the global generator the first time it is called,
and the generator will be placed at the default device at that time, so one
needs to be careful when this function is first called. Using a generator
placed on a less-ideal device will incur performance regression.
Returns:
The global `tf.random.experimental.Generator` object.
"""
global global_generator
if global_generator is None:
with ops.init_scope():

View File

@ -197,6 +197,17 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
check_results(expected_normal1, f(constructor))
check_results(expected_normal2, f(constructor))
@parameterized.parameters([
("philox", random.RNG_ALG_PHILOX, random.Algorithm.PHILOX),
("threefry", random.RNG_ALG_THREEFRY, random.Algorithm.THREEFRY)])
@test_util.run_v2_only
def testAlg(self, name, int_id, enum_id):
g_by_name = random.Generator.from_seed(1234, name)
g_by_int = random.Generator.from_seed(1234, int_id)
g_by_enum = random.Generator.from_seed(1234, enum_id)
self.assertEqual(g_by_name.algorithm, g_by_int.algorithm)
self.assertEqual(g_by_name.algorithm, g_by_enum.algorithm)
@test_util.run_v2_only
def testGeneratorCreationWithVar(self):
"""Tests creating generator with a variable.

View File

@ -0,0 +1,12 @@
path: "tensorflow.random.experimental.Algorithm"
tf_class {
is_instance: "<enum \'Algorithm\'>"
member {
name: "PHILOX"
mtype: "<enum \'Algorithm\'>"
}
member {
name: "THREEFRY"
mtype: "<enum \'Algorithm\'>"
}
}

View File

@ -1,12 +1,16 @@
path: "tensorflow.random.experimental"
tf_module {
member {
name: "Algorithm"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "Generator"
mtype: "<type \'type\'>"
}
member_method {
name: "create_rng_state"
argspec: "args=[\'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'seed\', \'alg\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_global_generator"

View File

@ -0,0 +1,12 @@
path: "tensorflow.random.experimental.Algorithm"
tf_class {
is_instance: "<enum \'Algorithm\'>"
member {
name: "PHILOX"
mtype: "<enum \'Algorithm\'>"
}
member {
name: "THREEFRY"
mtype: "<enum \'Algorithm\'>"
}
}

View File

@ -1,12 +1,16 @@
path: "tensorflow.random.experimental"
tf_module {
member {
name: "Algorithm"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "Generator"
mtype: "<type \'type\'>"
}
member_method {
name: "create_rng_state"
argspec: "args=[\'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'seed\', \'alg\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_global_generator"