Added alternative ways to initialize and reset a generator besides using a seed.

PiperOrigin-RevId: 248040749
Peng Wang 2019-05-13 17:05:03 -07:00 committed by TensorFlower Gardener
parent c21eca030c
commit 8ebb2d418c
8 changed files with 376 additions and 155 deletions

View File

@ -67,7 +67,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def testSimple(self, alg):
"""A simple test."""
with ops.device(xla_device_name()):
gen = random.Generator(seed=0, algorithm=alg)
gen = random.Generator.from_seed(seed=0, alg=alg)
gen.normal(shape=(3,))
gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32)
gen.uniform_full_int(shape=(3,))
@ -77,7 +77,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def testDefun(self, alg):
"""Test for defun."""
with ops.device(xla_device_name()):
gen = random.Generator(seed=0, algorithm=alg)
gen = random.Generator.from_seed(seed=0, alg=alg)
@def_function.function
def f():
x = gen.normal(shape=(3,))
@ -86,7 +86,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
return (x, y, z)
f()
def _compareToKnownOutputs(self, counter, key, expect):
def _compareToKnownOutputs(self, g, counter, key, expect):
"""Compares against known outputs for specific counter and key inputs."""
def uint32s_to_uint64(a, b):
return b << 32 | a
@ -99,13 +99,11 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
counter = uint32s_to_uint64s(counter)
key = uint32s_to_uint64s(key)
state = counter + key
random.get_global_generator().reset(state)
got = random.get_global_generator().uniform_full_int(
shape=(ctr_len,), dtype=dtypes.uint32)
g.reset(state)
got = g.uniform_full_int(shape=(ctr_len,), dtype=dtypes.uint32)
self.assertAllEqual(expect, got)
random.get_global_generator().reset(state)
got = random.get_global_generator().uniform_full_int(
shape=(ctr_len // 2,), dtype=dtypes.uint64)
g.reset(state)
got = g.uniform_full_int(shape=(ctr_len // 2,), dtype=dtypes.uint64)
self.assertAllEqual(uint32s_to_uint64s(expect), got)
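
For reference, the word packing used by `uint32s_to_uint64` above places the first 32-bit word in the low half of the 64-bit result. A standalone check, using values from the ThreeFry known-answer test below:

def uint32s_to_uint64(a, b):
  # Little-endian word order: `a` fills bits 0-31, `b` fills bits 32-63.
  return b << 32 | a

assert uint32s_to_uint64(0x6b200159, 0x99ba4efe) == 0x99ba4efe6b200159
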
@test_util.run_v2_only
@ -118,14 +116,17 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
# https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
with ops.device(xla_device_name()):
random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
g = random.Generator.from_seed(seed=0, alg=random.RNG_ALG_THREEFRY)
self._compareToKnownOutputs(
g,
[0x00000000, 0x00000000], [0x00000000, 0x00000000],
[0x6b200159, 0x99ba4efe])
self._compareToKnownOutputs(
g,
[0xffffffff, 0xffffffff], [0xffffffff, 0xffffffff],
[0x1cb996fc, 0xbb002be7])
self._compareToKnownOutputs(
g,
[0x243f6a88, 0x85a308d3], [0x13198a2e, 0x03707344],
[0xc4923a9c, 0x483df7a0])
@ -137,16 +138,19 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
# https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_philox.cpp#L50-L52
with ops.device(xla_device_name()):
random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_PHILOX)
g = random.Generator.from_seed(seed=0, alg=random.RNG_ALG_PHILOX)
self._compareToKnownOutputs(
g,
[0x00000000, 0x00000000, 0x00000000, 0x00000000],
[0x00000000, 0x00000000],
[0x6627e8d5, 0xe169c58d, 0xbc57ac4c, 0x9b00dbd8])
self._compareToKnownOutputs(
g,
[0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff],
[0xffffffff, 0xffffffff],
[0x408f276d, 0x41c83b0e, 0xa20bc7c6, 0x6d5451fd])
self._compareToKnownOutputs(
g,
[0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344],
[0xa4093822, 0x299f31d0],
[0xd16cfe09, 0x94fdcceb, 0x5001e420, 0x24126ea1])
@ -159,12 +163,11 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
counter = 57
key = 0x1234
size = 46
seed = [counter, key]
gen = random.Generator(
seed=seed, algorithm=random.RNG_ALG_THREEFRY)
state = [counter, key]
gen = random.Generator(state=state, alg=random.RNG_ALG_THREEFRY)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
self.assertAllEqual([counter+(size+1)//2, key], gen.state.read_value())
gen.reset(seed=seed)
gen.reset(state)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
self.assertAllEqual([counter+size, key], gen.state.read_value())
@ -177,13 +180,12 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
counter_high = 283
key = 0x1234
size = 47
seed = [counter_low, counter_high, key]
gen = random.Generator(
seed=seed, algorithm=random.RNG_ALG_PHILOX)
state = [counter_low, counter_high, key]
gen = random.Generator(state=state, alg=random.RNG_ALG_PHILOX)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
self.assertAllEqual([counter_low+(size+3)//4, counter_high, key],
gen.state.read_value())
gen.reset(seed=seed)
gen.reset(state)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
self.assertAllEqual([counter_low+(size+1)//2, counter_high, key],
gen.state.read_value())
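
The expected counter advances in these tests follow from the algorithms' block sizes: ThreeFry emits two uint32s per counter increment and Philox four, and a partial block still consumes a whole increment. A small sketch of that arithmetic (the helper below is illustrative, not part of the change):

# 32-bit output words produced per counter increment.
WORDS_PER_STEP = {"threefry": 2, "philox": 4}

def counter_advance(num_uint32s, alg):
  # Ceiling division: a partially used block still burns one increment.
  words = WORDS_PER_STEP[alg]
  return (num_uint32s + words - 1) // words

assert counter_advance(46, "threefry") == (46 + 1) // 2  # counter+(size+1)//2
assert counter_advance(47, "philox") == (47 + 3) // 4    # counter_low+(size+3)//4
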
@ -191,13 +193,12 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
counter_low = -1 # same as 0xffffffffffffffff
counter_high = 283
size = 47
seed = [counter_low, counter_high, key]
gen = random.Generator(
seed=seed, algorithm=random.RNG_ALG_PHILOX)
state = [counter_low, counter_high, key]
gen = random.Generator(state=state, alg=random.RNG_ALG_PHILOX)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
self.assertAllEqual([(size+3)//4-1, counter_high+1, key],
gen.state.read_value())
gen.reset(seed=seed)
gen.reset(state)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
self.assertAllEqual([(size+1)//2-1, counter_high+1, key],
gen.state.read_value())
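
The `counter_low = -1` case checks carry propagation: when the low 64-bit counter word overflows, the carry rolls into `counter_high`. A minimal sketch of that wraparound (illustrative only, not the op's implementation):

def advance_128bit_counter(counter_low, counter_high, delta):
  # Both words are unsigned 64-bit; Python ints stand in for them here.
  total = counter_low + delta
  return total % (1 << 64), (counter_high + (total >> 64)) % (1 << 64)

# Starting at 0xffffffffffffffff, 12 increments wrap to 11 and carry once,
# matching the [(size+3)//4-1, counter_high+1, key] assertion above:
assert advance_128bit_counter(0xffffffffffffffff, 283, 12) == (11, 284)
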
@ -209,10 +210,10 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
seed = 1234
shape = [315, 49]
with ops.device("/device:CPU:0"):
cpu = (random.Generator(seed=seed, algorithm=random.RNG_ALG_PHILOX)
cpu = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
.uniform_full_int(shape=shape, dtype=dtype))
with ops.device(xla_device_name()):
xla = (random.Generator(seed=seed, algorithm=random.RNG_ALG_PHILOX)
xla = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
.uniform_full_int(shape=shape, dtype=dtype))
self.assertAllEqual(cpu, xla)
@ -228,7 +229,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
@test_util.run_v2_only
def testUniformIsNotConstant(self, alg):
with ops.device(xla_device_name()):
gen = random.Generator(seed=1234, algorithm=alg)
gen = random.Generator.from_seed(seed=1234, alg=alg)
def rng(dtype):
maxval = dtype.max
# Workaround for b/125364959
@ -243,7 +244,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
@test_util.run_v2_only
def testNormalIsNotConstant(self, alg):
with ops.device(xla_device_name()):
gen = random.Generator(seed=1234, algorithm=alg)
gen = random.Generator.from_seed(seed=1234, alg=alg)
def rng(dtype):
return gen.normal(shape=[2], dtype=dtype)
@ -258,7 +259,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
size = 1000
with ops.device(xla_device_name()):
for dtype in self._ints + self._floats:
gen = random.Generator(seed=1234, algorithm=alg)
gen = random.Generator.from_seed(seed=1234, alg=alg)
x = gen.uniform(
shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
self.assertTrue(np.all(x >= minval))
@ -268,7 +269,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
@test_util.run_v2_only
def testNormalIsFinite(self, alg):
with ops.device(xla_device_name()):
gen = random.Generator(seed=1234, algorithm=alg)
gen = random.Generator.from_seed(seed=1234, alg=alg)
for dtype in self._floats:
x = gen.normal(shape=[10000], dtype=dtype).numpy()
self.assertTrue(np.all(np.isfinite(x)))
@ -281,7 +282,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
n = 1000
seed = 12
for dtype in self._ints + self._floats:
gen = random.Generator(seed=seed, algorithm=alg)
gen = random.Generator.from_seed(seed=seed, alg=alg)
maxval = 1
if dtype.is_integer:
maxval = 100
@ -303,7 +304,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
with ops.device(xla_device_name()):
n = 1000
for dtype in self._floats:
gen = random.Generator(seed=1234, algorithm=alg)
gen = random.Generator.from_seed(seed=1234, alg=alg)
x = gen.normal(shape=[n], dtype=dtype).numpy()
# The constant 2.492 is the 5% critical value for the Anderson-Darling
# test where the mean and variance are known. This test is probabilistic
@ -316,7 +317,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def testTruncatedNormal(self, alg):
with ops.device(xla_device_name()):
for dtype in self._floats:
gen = random.Generator(seed=123, algorithm=alg)
gen = random.Generator.from_seed(seed=123, alg=alg)
n = 10000000
y = gen.truncated_normal(shape=[n], dtype=dtype).numpy()
random_test_util.test_truncated_normal(
@ -328,7 +329,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
"""
shape = [2, 3]
with ops.device(xla_device_name()):
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
gen = random.Generator.from_seed(seed=1234, alg=random.RNG_ALG_THREEFRY)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
r"algorithm must be of shape \[\], not"):

View File

@ -38,7 +38,7 @@ class RandomBinomialTest(test.TestCase):
def _Sampler(self, num, counts, probs, dtype, seed=None):
def func():
rng = stateful_random_ops.Generator(seed=seed).binomial(
rng = stateful_random_ops.Generator.from_seed(seed).binomial(
shape=[10 * num], counts=counts, probs=probs, dtype=dtype)
ret = array_ops.reshape(rng, [10, num])
ret = self.evaluate(ret)
@ -80,11 +80,11 @@ class RandomBinomialTest(test.TestCase):
self.assertAllEqual(sx(), sy())
def testZeroShape(self):
rnd = stateful_random_ops.Generator(seed=12345).binomial([0], [], [])
rnd = stateful_random_ops.Generator.from_seed(12345).binomial([0], [], [])
self.assertEqual([0], rnd.shape.as_list())
def testShape(self):
rng = stateful_random_ops.Generator(seed=12345)
rng = stateful_random_ops.Generator.from_seed(12345)
# Scalar parameters.
rnd = rng.binomial(shape=[10], counts=np.float32(2.), probs=np.float32(0.5))
self.assertEqual([10], rnd.shape.as_list())

View File

@ -110,11 +110,15 @@ def _make_1d_state(state_size, seed):
raise ValueError(
"seed should only have one dimension; got shape: %s" % seed.shape)
seed = seed[0:state_size]
# Padding with zeros on the right if too short
# Padding with zeros on the *left* if too short. Padding on the right would
# cause a small seed to be used as the "counter" while the "key" is always
# zero (for counter-based RNG algorithms), because in the current memory
# layout counter is stored before key. In such a situation two RNGs with
# two different small seeds may generate overlapping outputs.
seed_size = seed.shape[0]
if seed_size < state_size:
seed = np.pad(
seed, [(0, state_size - seed_size)],
seed, [(state_size - seed_size, 0)],
mode="constant",
constant_values=0)
assert seed.shape == (state_size,), "Wrong seed.shape: %s" % seed.shape
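
A quick numpy illustration of why the padding side matters, assuming the [counter..., key] state layout described in the comment above (Philox shown):

import numpy as np

seed = np.array([5], dtype=np.uint64)  # a small, single-word seed
state_size = 3                         # Philox: [counter_low, counter_high, key]

right = np.pad(seed, [(0, state_size - 1)], mode="constant")  # old behavior
left = np.pad(seed, [(state_size - 1, 0)], mode="constant")   # new behavior

# Right-padding lands the seed in the counter and leaves the key at zero, so
# all small seeds share key 0 and their streams can overlap:
assert list(right) == [5, 0, 0]  # counter=(5, 0), key=0
# Left-padding lands the seed in the key, giving each seed its own stream:
assert list(left) == [0, 0, 5]   # counter=(0, 0), key=5
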
@ -157,6 +161,13 @@ def _shape_tensor(shape):
return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
def _convert_to_state_tensor(t):
if isinstance(t, list):
# map uint64 values onto int64 to avoid an out-of-range error from
# ops.convert_to_tensor
t = list(map(_uint_to_int, t))
return ops.convert_to_tensor(t, dtype=STATE_TYPE)
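
`_uint_to_int` itself is not shown in this diff; a plausible sketch of such a helper is below (reinterpreting a uint64 bit pattern as int64 via two's complement, so values above 2**63 - 1 no longer exceed int64's range). This exact body is an assumption, not the committed code:

def _uint_to_int(n):
  # Hypothetical: map a Python int holding uint64 bits onto the int64 value
  # with the same bit pattern.
  if n > 2**63 - 1:
    n -= 2**64
  return n

assert _uint_to_int(0xffffffffffffffff) == -1  # cf. "counter_low = -1" in the
                                               # overflow test above
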
@tf_export("random.experimental.Generator")
class Generator(tracking.AutoTrackable):
"""Random-number generator.
@ -164,57 +175,171 @@ class Generator(tracking.AutoTrackable):
It uses Variable to manage its internal state and allows choosing a
Random-Number-Generation (RNG) algorithm.
CPU and GPU with the same algorithm and seed will generate the same integer
random numbers. Floating-point results (such as the output of `normal`) may have
small numerical discrepancies between CPU and GPU.
Because of different counter-reservation schemes, TPU's integer random numbers
will be different from CPU/GPU even with the same algorithm and seed.
Also, TPU uses different sampling algorithms for some distributions
(e.g. using reverse CDF for sampling normal distribution instead of
Box-Muller used by CPU/GPU). Harmonizing TPU's RNG behavior with CPU/GPU is
work in progress.
CPU, GPU and TPU with the same algorithm and seed will generate the same
integer random numbers. Floating-point results (such as the output of `normal`)
may have small numerical discrepancies between CPU and GPU.
"""
def __init__(self, copy_from=None, seed=None, algorithm=None):
def __init__(self, copy_from=None, state=None, alg=None):
"""Creates a generator.
The new generator will be initialized in one of the following ways, in
decreasing order of precedence:
(1) If `copy_from` is not None, the new generator is initialized by copying
information from another generator.
(2) If `state` and `alg` are not None (they must be set together), the new
generator is initialized by a state.
Args:
copy_from: (optional) a generator to be copied from.
seed: (optional) the seed for the RNG. If None, it will be chosen
nondeterministically
algorithm: (optional) the RNG algorithm. If None, it will be
auto-selected.
copy_from: a generator to be copied from.
state: a vector of dtype STATE_TYPE representing the initial state of the
RNG, whose length and semantics are algorithm-specific.
alg: the RNG algorithm. Possible values are RNG_ALG_PHILOX for the
Philox algorithm and RNG_ALG_THREEFRY for the ThreeFry
algorithm (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
[https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
"""
if copy_from is None:
if algorithm is None:
# TODO(wangpeng): more sophisticated algorithm selection
algorithm = DEFAULT_ALGORITHM
if seed is None:
state = non_deterministic_ints(shape=[_get_state_size(algorithm)],
dtype=SEED_TYPE)
else:
state = create_rng_state(seed, algorithm)
self._state_var = variables.Variable(state,
dtype=STATE_TYPE,
trainable=False)
self._alg_var = algorithm
else:
assert seed is None
self._state_var = variables.Variable(copy_from.state,
dtype=STATE_TYPE,
if copy_from is not None:
# All other arguments should be None
assert (alg or state) is None
self._state_var = variables.Variable(copy_from.state, dtype=STATE_TYPE,
trainable=False)
self._alg_var = copy_from.algorithm
def reset(self, seed):
"""Resets the generator.
else:
assert alg is not None and state is not None
state = _convert_to_state_tensor(state)
state.shape.assert_is_compatible_with([_get_state_size(alg)])
self._state_var = variables.Variable(state, dtype=STATE_TYPE,
trainable=False)
self._alg_var = alg
@classmethod
def from_state(cls, state, alg):
"""Creates a generator from a state.
See `__init__` for a description of `state` and `alg`.
Args:
seed: the seed to reset the RNG to.
state: the new state.
alg: the RNG algorithm.
Returns:
The new generator.
"""
return cls(alg=alg, state=state)
@classmethod
def from_seed(cls, seed, alg=None):
"""Creates a generator from a seed.
A seed is a 1024-bit unsigned integer represented either as a Python
integer or a vector of integers. Seeds shorter than 1024 bits will be
padded. The padding, the internal structure of a seed and the way a seed
is converted to a state are all opaque (unspecified). The only semantic
specification of seeds is that two different seeds are likely to produce
two independent generators (but this is not guaranteed).
Args:
seed: the seed for the RNG.
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
`__init__` for its possible values.
Returns:
The new generator.
"""
if alg is None:
# TODO(wangpeng): more sophisticated algorithm selection
alg = DEFAULT_ALGORITHM
state = create_rng_state(seed, alg)
return cls(state=state, alg=alg)
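
A brief usage sketch of the two deterministic constructors, assuming Philox's [counter_low, counter_high, key] layout for the explicit state:

g1 = Generator.from_seed(1234)  # algorithm auto-selected
g2 = Generator.from_state(state=[0, 0, 1234], alg=RNG_ALG_PHILOX)
# Note: g1 and g2 are generally different generators, since the mapping from
# seed to state is opaque. Each is individually deterministic, though:
assert (g1.normal(shape=[3]).numpy() ==
        Generator.from_seed(1234).normal(shape=[3]).numpy()).all()
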
@classmethod
def from_non_deterministic_state(cls, alg=None):
"""Creates a generator by non-deterministically initializing its state.
The source of the non-determinism will be platform- and time-dependent.
Args:
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
`__init__` for its possible values.
Returns:
The new generator.
"""
if alg is None:
# TODO(wangpeng): more sophisticated algorithm selection
alg = DEFAULT_ALGORITHM
state = non_deterministic_ints(shape=[_get_state_size(alg)],
dtype=SEED_TYPE)
return cls(state=state, alg=alg)
@classmethod
def from_key_counter(cls, key, counter, alg):
"""Creates a generator from a key and a counter.
This constructor only applies to counter-based RNG algorithms.
See method `key` for the meaning of "key" and "counter".
Args:
key: the key for the RNG, a scalar of type STATE_TYPE.
counter: a vector of dtype STATE_TYPE representing the initial counter for
the RNG, whose length is algorithm-specific.
alg: the RNG algorithm. See `__init__` for its possible values.
Returns:
The new generator.
"""
counter = _convert_to_state_tensor(counter)
key = _convert_to_state_tensor(key)
counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
key.shape.assert_is_compatible_with([])
key = array_ops.reshape(key, [1])
state = array_ops.concat([counter, key], 0)
return cls(state=state, alg=alg)
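
The state assembled here is just the counter words followed by the key, so `from_key_counter` is equivalent to passing the concatenated state directly. A sketch using the same key/counter values as the tests below:

g_a = Generator.from_key_counter(key=1, counter=[2, 3], alg=RNG_ALG_PHILOX)
g_b = Generator(state=[2, 3, 1], alg=RNG_ALG_PHILOX)  # concat(counter, [key])
assert (g_a.state.read_value().numpy() ==
        g_b.state.read_value().numpy()).all()
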
def reset(self, state):
"""Resets the generator by a new state.
See `__init__` for the meaning of "state".
Args:
state: the new state.
"""
state = _convert_to_state_tensor(state)
state.shape.assert_is_compatible_with([_get_state_size(self.algorithm)])
self._state_var.assign(state)
def reset_from_seed(self, seed):
"""Resets the generator by a new seed.
See `from_seed` for the meaning of "seed".
Args:
seed: the new seed.
"""
state = create_rng_state(seed, self.algorithm)
self._state_var.assign(state)
def reset_from_key_counter(self, key, counter):
"""Resets the generator by a new key-counter pair.
See `from_key_counter` for the meaning of "key" and "counter".
Args:
key: the new key.
counter: the new counter.
"""
counter = _convert_to_state_tensor(counter)
key = _convert_to_state_tensor(key)
counter.shape.assert_is_compatible_with(
[_get_state_size(self.algorithm) - 1])
key.shape.assert_is_compatible_with([])
key = array_ops.reshape(key, [1])
state = array_ops.concat([counter, key], 0)
self._state_var.assign(state)
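
Each reset method mirrors one of the constructors above; resetting and then sampling reproduces the stream of a freshly built generator. A short sketch:

g = Generator.from_seed(0)
first = g.normal(shape=[2]).numpy()
g.reset_from_seed(0)  # back to the state from_seed(0) would produce
assert (g.normal(shape=[2]).numpy() == first).all()
# reset(state) and reset_from_key_counter(key, counter) behave analogously,
# mirroring from_state and from_key_counter.
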
@property
def state(self):
"""The internal state of the RNG."""
@ -235,7 +360,7 @@ class Generator(tracking.AutoTrackable):
For a counter-based RNG algorithm such as Philox and ThreeFry (as
described in paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
(https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)),
[https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]),
the RNG state consists of two parts: counter and key. The output is
generated via the formula: output=hash(key, counter), i.e. a hashing of
the counter parametrized by the key. Two RNGs with two different keys can
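
As a toy model of the output = hash(key, counter) scheme (SHA-256 as a stand-in hash; the real ops use Philox/ThreeFry rounds):

import hashlib

def toy_counter_rng(key, counter):
  # A counter-based RNG is a pure function of (key, counter): advancing the
  # stream only increments the counter, and distinct keys yield independent
  # streams.
  digest = hashlib.sha256(key.to_bytes(8, "little") +
                          counter.to_bytes(16, "little")).digest()
  return int.from_bytes(digest[:8], "little")  # one 64-bit output block

stream_a = [toy_counter_rng(1, c) for c in range(4)]
stream_b = [toy_counter_rng(2, c) for c in range(4)]
assert stream_a != stream_b  # different keys give different streams
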
@ -535,7 +660,7 @@ class Generator(tracking.AutoTrackable):
alg = self.algorithm
if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
keys = self._make_int64_keys(shape=[count])
return [Generator(seed=_key_to_state(alg, key), algorithm=alg)
return [Generator(state=_key_to_state(alg, key), alg=alg)
for key in keys.numpy()]
else:
raise ValueError("Unsupported algorithm id: %s" % alg)
@ -551,31 +676,27 @@ global_generator = None
def get_global_generator():
global global_generator
if global_generator is None:
global_generator = Generator()
global_generator = Generator.from_non_deterministic_state()
return global_generator
@tf_export("random.experimental.set_global_generator")
def set_global_generator(generator):
"""Replaces the global generator with another `Generator` object.
This function creates a new Generator object (and the Variable object within),
which does not work well with tf.function because (1) tf.function puts
restrictions on Variable creation thus reset_global_generator can't be freely
used inside tf.function; (2) redirecting a global variable to
a new object is problematic with tf.function because the old object may be
captured by a 'tf.function'ed function and still be used by it.
A 'tf.function'ed function only keeps weak references to variables,
so deleting a variable and then calling that function again may raise an
error, as demonstrated by
random_test.py/RandomTest.testResetGlobalGeneratorBadWithDefun .
Args:
generator: the new `Generator` object.
"""
global global_generator
global_generator = generator
# This function creates a new Generator object (and the Variable object within),
# which does not work well with tf.function because (1) tf.function puts
# restrictions on Variable creation thus reset_global_generator can't be freely
# used inside tf.function; (2) redirecting a global variable to
# a new object is problematic with tf.function because the old object may be
# captured by a 'tf.function'ed function and still be used by it.
# A 'tf.function'ed function only keeps weak references to variables,
# so deleting a variable and then calling that function again may raise an
# error, as demonstrated by
# random_test.py/RandomTest.testResetGlobalGeneratorBadWithDefun .
# The function 'set_global_generator' below also has this problem.
@tf_export("random.experimental.reset_global_generator")
def reset_global_generator(seed, algorithm=None):
global global_generator
if algorithm is None:
# preserve the old algorithm
algorithm = int(get_global_generator().algorithm)
global_generator = Generator(seed=seed, algorithm=algorithm)

View File

@ -96,7 +96,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
"""
shape = [2, 3]
count = 6
gen = random.Generator(seed=1234)
gen = random.Generator.from_seed(1234)
keys1 = gen._make_int64_keys(shape=shape)
keys2 = gen._make_int64_keys(shape=shape)
self.assertAllDifferent([keys1, keys2])
@ -126,42 +126,100 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
"""Tests that a CPU RNG can split into RNGs on GPU.
"""
with ops.device("/device:CPU:0"):
gen = random.Generator(seed=1234) # gen is on CPU
gen = random.Generator.from_seed(1234) # gen is on CPU
self.assertRegex("CPU", gen.state.device)
with ops.device(test_util.gpu_device_name()):
gens = gen.split(count=10) # gens are on GPU
self.assertRegex("GPU", gens[0].state.device)
@test_util.run_v2_only
def testGeneratorCreationInDefun(self):
"""Tests creating a Generator in defun.
def testReset(self):
shape = [2, 3]
gen = random.Generator.from_seed(0)
for resetter in [
lambda g: g.reset(state=[1, 2, 3]),
lambda g: g.reset_from_seed(1234),
lambda g: g.reset_from_key_counter(key=1, counter=[2, 3]),
]:
resetter(gen)
expected_normal = gen.normal(shape)
@def_function.function
def f(resetter):
resetter(gen)
return gen.normal(shape)
def check_results(expected_normal, v):
self.assertAllEqual(expected_normal, v)
check_results(expected_normal, f(resetter))
check_results(expected_normal, f(resetter))
@test_util.run_v2_only
def testGeneratorCreation(self):
"""Tests generator creation, in both eager and tf.function.
The interaction between Generator creation and defun should be the same as
tf.Variable.
"""
seed = 1234
shape = [2, 3]
with ops.device("/device:CPU:0"):
gen = random.Generator(seed=seed)
alg = random.RNG_ALG_PHILOX
for constructor in [
lambda: random.Generator(state=[1, 2, 3], alg=alg),
lambda: random.Generator.from_seed(1234),
lambda: random.Generator.from_key_counter( # pylint: disable=g-long-lambda
key=1, counter=[2, 3], alg=alg),
]:
gen = constructor()
# Tests tf.function
expected_normal1 = gen.normal(shape)
expected_normal2 = gen.normal(shape)
global g_seeded
g_seeded = None
@def_function.function
def f():
def f(constructor):
global g_seeded
global g_unseeded
# defun'ed function should only create variables once
if g_seeded is None:
g_seeded = random.Generator(seed=seed)
if g_unseeded is None:
g_unseeded = random.Generator()
r = g_seeded.normal(shape)
r = (r, g_unseeded.normal(shape))
return r
def check_results(expected_normal, v1, v2):
self.assertAllEqual(expected_normal, v1)
self.assertAllEqual(shape, v2.shape)
check_results(expected_normal1, *f())
check_results(expected_normal2, *f())
g_seeded = constructor()
return g_seeded.normal(shape)
def check_results(expected_normal, v):
self.assertAllEqual(expected_normal, v)
check_results(expected_normal1, f(constructor))
check_results(expected_normal2, f(constructor))
@test_util.run_v2_only
def testGeneratorCreationUnseeded(self):
"""Tests generator creation, the unseeded case."""
shape = [2, 3]
global g_unseeded
g_unseeded = None
@def_function.function
def f():
global g_unseeded
# defun'ed function should only create variables once
if g_unseeded is None:
g_unseeded = random.Generator.from_non_deterministic_state()
return g_unseeded.normal(shape)
self.assertAllEqual(shape, f().shape)
@test_util.run_v2_only
def testGeneratorCopy(self):
"""Tests copying a generator."""
g = random.Generator.from_seed(0)
g_copy = random.Generator(g)
self.assertAllEqual(g.algorithm, g_copy.algorithm)
self.assertAllEqual(g.state.read_value(), g_copy.state.read_value())
# Tests tf.function
global g_seeded
g_seeded = None
# Do the same in tf.function
@def_function.function
def f():
global g_seeded
# defun'ed function should only create variables once
if g_seeded is None:
g_seeded = random.Generator(g)
self.assertAllEqual(g.algorithm, g_seeded.algorithm)
self.assertAllEqual(g.state.read_value(), g_seeded.state.read_value())
f()
@test_util.run_v1_only(
("This test is specifically for checking TF1 compatibility. "
@ -176,8 +234,8 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
[[-0.3964749, 0.8369565, -0.30946946],
[1.1206646, 1.00852597, -0.10185789]], dtype=dtypes.float32)
with self.cached_session() as sess:
gen1 = random.Generator(seed=seed)
gen2 = random.Generator()
gen1 = random.Generator.from_seed(seed)
gen2 = random.Generator.from_non_deterministic_state()
sess.run((gen1._state_var.initializer, gen2._state_var.initializer))
r1 = gen1.normal(shape, dtype=dtypes.float32)
r2 = gen2.normal(shape, dtype=dtypes.float32)
@ -203,9 +261,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
https://github.com/tensorflow/tensorflow/issues/9171
"""
shape = (3,)
random.get_global_generator().reset(1)
random.get_global_generator().reset_from_seed(1)
a = random.get_global_generator().normal(shape)
random.get_global_generator().reset(1)
random.get_global_generator().reset_from_seed(1)
b = random.get_global_generator().normal(shape)
self.assertAllEqual(a, b)
@ -214,9 +272,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
def f():
return random.get_global_generator().normal(shape)
random.get_global_generator().reset(1)
random.get_global_generator().reset_from_seed(1)
c = f()
random.get_global_generator().reset(1)
random.get_global_generator().reset_from_seed(1)
d = f()
self.assertAllEqual(c, d)
self.assertAllEqual(a, c)
@ -237,9 +295,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
return random.get_global_generator().normal(shape)
def compare(fst_includes_print, snd_includes_print):
random.get_global_generator().reset(50)
random.get_global_generator().reset_from_seed(50)
fst = f(fst_includes_print)
random.get_global_generator().reset(50)
random.get_global_generator().reset_from_seed(50)
snd = f(snd_includes_print)
self.assertAllEqual(fst, snd)
# Now do the above again using accelerated (defunned) 'f'.
@ -247,9 +305,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
# two different graphs to be generated, hence demonstrating the
# insensitivity to graph changes.
f_acc = def_function.function(f)
random.get_global_generator().reset(50)
random.get_global_generator().reset_from_seed(50)
fst = f_acc(fst_includes_print)
random.get_global_generator().reset(50)
random.get_global_generator().reset_from_seed(50)
snd = f_acc(snd_includes_print)
self.assertAllEqual(fst, snd)
@ -260,7 +318,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
@test_util.run_v2_only
def testKey(self):
key = 1234
gen = random.Generator(seed=[0, 0, key])
gen = random.Generator(state=[0, 0, key], alg=random.RNG_ALG_PHILOX)
got = gen.key
self.assertAllEqual(key, got)
@def_function.function
@ -273,7 +331,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
def testSkip(self):
key = 1234
counter = 5678
gen = random.Generator(seed=[counter, 0, key])
gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
delta = 432
gen.skip(delta)
new_counter = gen._state_var[0]
@ -285,7 +343,8 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
# note how the two seeds for the old op correspond to the seed for the new
# op
with ops.device(device):
gen = random.Generator(seed=[0, seed2, seed1])
gen = random.Generator(state=[0, seed2, seed1],
alg=random.RNG_ALG_PHILOX)
# create a graph for the old op in order to call it many times
@def_function.function
@ -361,10 +420,10 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
seed = 1234
shape = [315, 49]
with ops.device("/device:CPU:0"):
cpu = random.Generator(seed=seed).uniform_full_int(
cpu = random.Generator.from_seed(seed).uniform_full_int(
shape=shape, dtype=dtype)
with ops.device(test_util.gpu_device_name()):
gpu = random.Generator(seed=seed).uniform_full_int(
gpu = random.Generator.from_seed(seed).uniform_full_int(
shape=shape, dtype=dtype)
self.assertAllEqual(cpu, gpu)
@ -374,7 +433,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
minval = 2
maxval = 33
size = 1000
gen = random.Generator(seed=1234)
gen = random.Generator.from_seed(1234)
x = gen.uniform(
shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
self.assertTrue(np.all(x >= minval))
@ -383,7 +442,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(FLOATS)
@test_util.run_v2_only
def testNormalIsFinite(self, dtype):
gen = random.Generator(seed=1234)
gen = random.Generator.from_seed(1234)
x = gen.normal(shape=[10000], dtype=dtype).numpy()
self.assertTrue(np.all(np.isfinite(x)))
@ -393,7 +452,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
"""Use Pearson's Chi-squared test to test for uniformity."""
n = 1000
seed = 12
gen = random.Generator(seed=seed)
gen = random.Generator.from_seed(seed)
maxval = 1
if dtype.is_integer:
maxval = 100
@ -413,7 +472,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
def testDistributionOfNormal(self, dtype):
"""Use Anderson-Darling test to test distribution appears normal."""
n = 1000
gen = random.Generator(seed=1234)
gen = random.Generator.from_seed(1234)
x = gen.normal(shape=[n], dtype=dtype).numpy()
# The constant 2.492 is the 5% critical value for the Anderson-Darling
# test where the mean and variance are known. This test is probabilistic
@ -426,7 +485,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
"""Tests that proper errors are raised.
"""
shape = [2, 3]
gen = random.Generator(seed=1234)
gen = random.Generator.from_seed(1234)
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
r"must have shape \[\], not"):
@ -466,8 +525,8 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
var.handle, random.RNG_ALG_PHILOX, shape)
@test_util.run_v2_only
def testResetGlobalGeneratorBadWithDefun(self):
"""Demonstrates that reset_global_generator don't work properly with defun.
def testSetGlobalGeneratorBadWithDefun(self):
"""Demonstrates that set_global_generator don't work properly with defun.
"""
shape = (3,)
@ -475,11 +534,11 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
def f():
return random.get_global_generator().normal(shape)
random.reset_global_generator(50)
random.set_global_generator(random.Generator.from_seed(50))
with self.assertRaisesWithPredicateMatch(
errors.NotFoundError, "Resource .+ does not exist"):
_ = f()
random.reset_global_generator(50)
random.set_global_generator(random.Generator.from_seed(50))
_ = f()
@test_util.run_v2_only
@ -492,7 +551,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
"""
shape = [3, 4]
dtype = dtypes.int32
gen = random.Generator(seed=1234)
gen = random.Generator.from_seed(1234)
strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
with strat.scope():
def f():
@ -519,7 +578,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
dtype = dtypes.int32
strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
with strat.scope():
gen = random.Generator(seed=1234)
gen = random.Generator.from_seed(1234)
def f():
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
@ -545,7 +604,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
dtype = dtypes.int32
strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
def f():
gen = random.Generator(seed=1234)
gen = random.Generator.from_seed(1234)
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
t = array_ops.stack([t1, t2])
@ -575,7 +634,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
# different random-number stream. The only obstacle is that op
# 'NonDeterministicInts' is not implemented on GPU.)
with strat.scope():
gen = random.Generator()
gen = random.Generator.from_non_deterministic_state()
def f():
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
t2 = gen.uniform_full_int(shape=shape, dtype=dtype)

View File

@ -18,12 +18,28 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'copy_from\', \'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'copy_from\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "binomial"
argspec: "args=[\'self\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
}
member_method {
name: "from_key_counter"
argspec: "args=[\'cls\', \'key\', \'counter\', \'alg\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_non_deterministic_state"
argspec: "args=[\'cls\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_seed"
argspec: "args=[\'cls\', \'seed\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_state"
argspec: "args=[\'cls\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_seeds"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
@ -34,6 +50,14 @@ tf_class {
}
member_method {
name: "reset"
argspec: "args=[\'self\', \'state\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reset_from_key_counter"
argspec: "args=[\'self\', \'key\', \'counter\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reset_from_seed"
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=None"
}
member_method {

View File

@ -12,10 +12,6 @@ tf_module {
name: "get_global_generator"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reset_global_generator"
argspec: "args=[\'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "set_global_generator"
argspec: "args=[\'generator\'], varargs=None, keywords=None, defaults=None"

View File

@ -18,12 +18,28 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'copy_from\', \'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'copy_from\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "binomial"
argspec: "args=[\'self\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
}
member_method {
name: "from_key_counter"
argspec: "args=[\'cls\', \'key\', \'counter\', \'alg\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_non_deterministic_state"
argspec: "args=[\'cls\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_seed"
argspec: "args=[\'cls\', \'seed\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "from_state"
argspec: "args=[\'cls\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_seeds"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
@ -34,6 +50,14 @@ tf_class {
}
member_method {
name: "reset"
argspec: "args=[\'self\', \'state\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reset_from_key_counter"
argspec: "args=[\'self\', \'key\', \'counter\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reset_from_seed"
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=None"
}
member_method {

View File

@ -12,10 +12,6 @@ tf_module {
name: "get_global_generator"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reset_global_generator"
argspec: "args=[\'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "set_global_generator"
argspec: "args=[\'generator\'], varargs=None, keywords=None, defaults=None"