Added alternative ways to initialize and reset a generator other than seed
.
PiperOrigin-RevId: 248040749
This commit is contained in:
parent
c21eca030c
commit
8ebb2d418c
tensorflow
compiler/tests
python
tools/api/golden
@ -67,7 +67,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
def testSimple(self, alg):
|
||||
"""A simple test."""
|
||||
with ops.device(xla_device_name()):
|
||||
gen = random.Generator(seed=0, algorithm=alg)
|
||||
gen = random.Generator.from_seed(seed=0, alg=alg)
|
||||
gen.normal(shape=(3,))
|
||||
gen.uniform(shape=(3,), minval=0, maxval=10, dtype=dtypes.uint32)
|
||||
gen.uniform_full_int(shape=(3,))
|
||||
@ -77,7 +77,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
def testDefun(self, alg):
|
||||
"""Test for defun."""
|
||||
with ops.device(xla_device_name()):
|
||||
gen = random.Generator(seed=0, algorithm=alg)
|
||||
gen = random.Generator.from_seed(seed=0, alg=alg)
|
||||
@def_function.function
|
||||
def f():
|
||||
x = gen.normal(shape=(3,))
|
||||
@ -86,7 +86,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
return (x, y, z)
|
||||
f()
|
||||
|
||||
def _compareToKnownOutputs(self, counter, key, expect):
|
||||
def _compareToKnownOutputs(self, g, counter, key, expect):
|
||||
"""Compares against known outputs for specific counter and key inputs."""
|
||||
def uint32s_to_uint64(a, b):
|
||||
return b << 32 | a
|
||||
@ -99,13 +99,11 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
counter = uint32s_to_uint64s(counter)
|
||||
key = uint32s_to_uint64s(key)
|
||||
state = counter + key
|
||||
random.get_global_generator().reset(state)
|
||||
got = random.get_global_generator().uniform_full_int(
|
||||
shape=(ctr_len,), dtype=dtypes.uint32)
|
||||
g.reset(state)
|
||||
got = g.uniform_full_int(shape=(ctr_len,), dtype=dtypes.uint32)
|
||||
self.assertAllEqual(expect, got)
|
||||
random.get_global_generator().reset(state)
|
||||
got = random.get_global_generator().uniform_full_int(
|
||||
shape=(ctr_len // 2,), dtype=dtypes.uint64)
|
||||
g.reset(state)
|
||||
got = g.uniform_full_int(shape=(ctr_len // 2,), dtype=dtypes.uint64)
|
||||
self.assertAllEqual(uint32s_to_uint64s(expect), got)
|
||||
|
||||
@test_util.run_v2_only
|
||||
@ -118,14 +116,17 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
# https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_threefry.cpp#L30-L32
|
||||
|
||||
with ops.device(xla_device_name()):
|
||||
random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_THREEFRY)
|
||||
g = random.Generator.from_seed(seed=0, alg=random.RNG_ALG_THREEFRY)
|
||||
self._compareToKnownOutputs(
|
||||
g,
|
||||
[0x00000000, 0x00000000], [0x00000000, 0x00000000],
|
||||
[0x6b200159, 0x99ba4efe])
|
||||
self._compareToKnownOutputs(
|
||||
g,
|
||||
[0xffffffff, 0xffffffff], [0xffffffff, 0xffffffff],
|
||||
[0x1cb996fc, 0xbb002be7])
|
||||
self._compareToKnownOutputs(
|
||||
g,
|
||||
[0x243f6a88, 0x85a308d3], [0x13198a2e, 0x03707344],
|
||||
[0xc4923a9c, 0x483df7a0])
|
||||
|
||||
@ -137,16 +138,19 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
# https://github.com/DEShawResearch/Random123-Boost/blob/65e3d874b67aa7b3e02d5ad8306462f52d2079c0/libs/random/test/test_philox.cpp#L50-L52
|
||||
|
||||
with ops.device(xla_device_name()):
|
||||
random.reset_global_generator(seed=0, algorithm=random.RNG_ALG_PHILOX)
|
||||
g = random.Generator.from_seed(seed=0, alg=random.RNG_ALG_PHILOX)
|
||||
self._compareToKnownOutputs(
|
||||
g,
|
||||
[0x00000000, 0x00000000, 0x00000000, 0x00000000],
|
||||
[0x00000000, 0x00000000],
|
||||
[0x6627e8d5, 0xe169c58d, 0xbc57ac4c, 0x9b00dbd8])
|
||||
self._compareToKnownOutputs(
|
||||
g,
|
||||
[0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff],
|
||||
[0xffffffff, 0xffffffff],
|
||||
[0x408f276d, 0x41c83b0e, 0xa20bc7c6, 0x6d5451fd])
|
||||
self._compareToKnownOutputs(
|
||||
g,
|
||||
[0x243f6a88, 0x85a308d3, 0x13198a2e, 0x03707344],
|
||||
[0xa4093822, 0x299f31d0],
|
||||
[0xd16cfe09, 0x94fdcceb, 0x5001e420, 0x24126ea1])
|
||||
@ -159,12 +163,11 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
counter = 57
|
||||
key = 0x1234
|
||||
size = 46
|
||||
seed = [counter, key]
|
||||
gen = random.Generator(
|
||||
seed=seed, algorithm=random.RNG_ALG_THREEFRY)
|
||||
state = [counter, key]
|
||||
gen = random.Generator(state=state, alg=random.RNG_ALG_THREEFRY)
|
||||
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
|
||||
self.assertAllEqual([counter+(size+1)//2, key], gen.state.read_value())
|
||||
gen.reset(seed=seed)
|
||||
gen.reset(state)
|
||||
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
|
||||
self.assertAllEqual([counter+size, key], gen.state.read_value())
|
||||
|
||||
@ -177,13 +180,12 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
counter_high = 283
|
||||
key = 0x1234
|
||||
size = 47
|
||||
seed = [counter_low, counter_high, key]
|
||||
gen = random.Generator(
|
||||
seed=seed, algorithm=random.RNG_ALG_PHILOX)
|
||||
state = [counter_low, counter_high, key]
|
||||
gen = random.Generator(state=state, alg=random.RNG_ALG_PHILOX)
|
||||
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
|
||||
self.assertAllEqual([counter_low+(size+3)//4, counter_high, key],
|
||||
gen.state.read_value())
|
||||
gen.reset(seed=seed)
|
||||
gen.reset(state)
|
||||
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
|
||||
self.assertAllEqual([counter_low+(size+1)//2, counter_high, key],
|
||||
gen.state.read_value())
|
||||
@ -191,13 +193,12 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
counter_low = -1 # same as 0xffffffffffffffff
|
||||
counter_high = 283
|
||||
size = 47
|
||||
seed = [counter_low, counter_high, key]
|
||||
gen = random.Generator(
|
||||
seed=seed, algorithm=random.RNG_ALG_PHILOX)
|
||||
state = [counter_low, counter_high, key]
|
||||
gen = random.Generator(state=state, alg=random.RNG_ALG_PHILOX)
|
||||
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
|
||||
self.assertAllEqual([(size+3)//4-1, counter_high+1, key],
|
||||
gen.state.read_value())
|
||||
gen.reset(seed=seed)
|
||||
gen.reset(state)
|
||||
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
|
||||
self.assertAllEqual([(size+1)//2-1, counter_high+1, key],
|
||||
gen.state.read_value())
|
||||
@ -209,10 +210,10 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
seed = 1234
|
||||
shape = [315, 49]
|
||||
with ops.device("/device:CPU:0"):
|
||||
cpu = (random.Generator(seed=seed, algorithm=random.RNG_ALG_PHILOX)
|
||||
cpu = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
|
||||
.uniform_full_int(shape=shape, dtype=dtype))
|
||||
with ops.device(xla_device_name()):
|
||||
xla = (random.Generator(seed=seed, algorithm=random.RNG_ALG_PHILOX)
|
||||
xla = (random.Generator.from_seed(seed=seed, alg=random.RNG_ALG_PHILOX)
|
||||
.uniform_full_int(shape=shape, dtype=dtype))
|
||||
self.assertAllEqual(cpu, xla)
|
||||
|
||||
@ -228,7 +229,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
@test_util.run_v2_only
|
||||
def testUniformIsNotConstant(self, alg):
|
||||
with ops.device(xla_device_name()):
|
||||
gen = random.Generator(seed=1234, algorithm=alg)
|
||||
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
||||
def rng(dtype):
|
||||
maxval = dtype.max
|
||||
# Workaround for b/125364959
|
||||
@ -243,7 +244,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
@test_util.run_v2_only
|
||||
def testNormalIsNotConstant(self, alg):
|
||||
with ops.device(xla_device_name()):
|
||||
gen = random.Generator(seed=1234, algorithm=alg)
|
||||
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
||||
def rng(dtype):
|
||||
return gen.normal(shape=[2], dtype=dtype)
|
||||
|
||||
@ -258,7 +259,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
size = 1000
|
||||
with ops.device(xla_device_name()):
|
||||
for dtype in self._ints + self._floats:
|
||||
gen = random.Generator(seed=1234, algorithm=alg)
|
||||
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
||||
x = gen.uniform(
|
||||
shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
|
||||
self.assertTrue(np.all(x >= minval))
|
||||
@ -268,7 +269,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
@test_util.run_v2_only
|
||||
def testNormalIsFinite(self, alg):
|
||||
with ops.device(xla_device_name()):
|
||||
gen = random.Generator(seed=1234, algorithm=alg)
|
||||
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
||||
for dtype in self._floats:
|
||||
x = gen.normal(shape=[10000], dtype=dtype).numpy()
|
||||
self.assertTrue(np.all(np.isfinite(x)))
|
||||
@ -281,7 +282,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
n = 1000
|
||||
seed = 12
|
||||
for dtype in self._ints + self._floats:
|
||||
gen = random.Generator(seed=seed, algorithm=alg)
|
||||
gen = random.Generator.from_seed(seed=seed, alg=alg)
|
||||
maxval = 1
|
||||
if dtype.is_integer:
|
||||
maxval = 100
|
||||
@ -303,7 +304,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
with ops.device(xla_device_name()):
|
||||
n = 1000
|
||||
for dtype in self._floats:
|
||||
gen = random.Generator(seed=1234, algorithm=alg)
|
||||
gen = random.Generator.from_seed(seed=1234, alg=alg)
|
||||
x = gen.normal(shape=[n], dtype=dtype).numpy()
|
||||
# The constant 2.492 is the 5% critical value for the Anderson-Darling
|
||||
# test where the mean and variance are known. This test is probabilistic
|
||||
@ -316,7 +317,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
def testTruncatedNormal(self, alg):
|
||||
with ops.device(xla_device_name()):
|
||||
for dtype in self._floats:
|
||||
gen = random.Generator(seed=123, algorithm=alg)
|
||||
gen = random.Generator.from_seed(seed=123, alg=alg)
|
||||
n = 10000000
|
||||
y = gen.truncated_normal(shape=[n], dtype=dtype).numpy()
|
||||
random_test_util.test_truncated_normal(
|
||||
@ -328,7 +329,7 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
"""
|
||||
shape = [2, 3]
|
||||
with ops.device(xla_device_name()):
|
||||
gen = random.Generator(seed=1234, algorithm=random.RNG_ALG_THREEFRY)
|
||||
gen = random.Generator.from_seed(seed=1234, alg=random.RNG_ALG_THREEFRY)
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
errors_impl.InvalidArgumentError,
|
||||
r"algorithm must be of shape \[\], not"):
|
||||
|
@ -38,7 +38,7 @@ class RandomBinomialTest(test.TestCase):
|
||||
def _Sampler(self, num, counts, probs, dtype, seed=None):
|
||||
|
||||
def func():
|
||||
rng = stateful_random_ops.Generator(seed=seed).binomial(
|
||||
rng = stateful_random_ops.Generator.from_seed(seed).binomial(
|
||||
shape=[10 * num], counts=counts, probs=probs, dtype=dtype)
|
||||
ret = array_ops.reshape(rng, [10, num])
|
||||
ret = self.evaluate(ret)
|
||||
@ -80,11 +80,11 @@ class RandomBinomialTest(test.TestCase):
|
||||
self.assertAllEqual(sx(), sy())
|
||||
|
||||
def testZeroShape(self):
|
||||
rnd = stateful_random_ops.Generator(seed=12345).binomial([0], [], [])
|
||||
rnd = stateful_random_ops.Generator.from_seed(12345).binomial([0], [], [])
|
||||
self.assertEqual([0], rnd.shape.as_list())
|
||||
|
||||
def testShape(self):
|
||||
rng = stateful_random_ops.Generator(seed=12345)
|
||||
rng = stateful_random_ops.Generator.from_seed(12345)
|
||||
# Scalar parameters.
|
||||
rnd = rng.binomial(shape=[10], counts=np.float32(2.), probs=np.float32(0.5))
|
||||
self.assertEqual([10], rnd.shape.as_list())
|
||||
|
@ -110,11 +110,15 @@ def _make_1d_state(state_size, seed):
|
||||
raise ValueError(
|
||||
"seed should only have one dimension; got shape: %s" % seed.shape)
|
||||
seed = seed[0:state_size]
|
||||
# Padding with zeros on the right if too short
|
||||
# Padding with zeros on the *left* if too short. Padding on the right would
|
||||
# cause a small seed to be used as the "counter" while the "key" is always
|
||||
# zero (for counter-based RNG algorithms), because in the current memory
|
||||
# layout counter is stored before key. In such a situation two RNGs with
|
||||
# two different small seeds may generate overlapping outputs.
|
||||
seed_size = seed.shape[0]
|
||||
if seed_size < state_size:
|
||||
seed = np.pad(
|
||||
seed, [(0, state_size - seed_size)],
|
||||
seed, [(state_size - seed_size, 0)],
|
||||
mode="constant",
|
||||
constant_values=0)
|
||||
assert seed.shape == (state_size,), "Wrong seed.shape: %s" % seed.shape
|
||||
@ -157,6 +161,13 @@ def _shape_tensor(shape):
|
||||
return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
|
||||
|
||||
|
||||
def _convert_to_state_tensor(t):
|
||||
if isinstance(t, list):
|
||||
# to avoid out-of-range error from ops.convert_to_tensor
|
||||
t = list(map(_uint_to_int, t))
|
||||
return ops.convert_to_tensor(t, dtype=STATE_TYPE)
|
||||
|
||||
|
||||
@tf_export("random.experimental.Generator")
|
||||
class Generator(tracking.AutoTrackable):
|
||||
"""Random-number generator.
|
||||
@ -164,57 +175,171 @@ class Generator(tracking.AutoTrackable):
|
||||
It uses Variable to manage its internal state, and allows choosing an
|
||||
Random-Number-Generation (RNG) algorithm.
|
||||
|
||||
CPU and GPU with the same algorithm and seed will generate the same integer
|
||||
random numbers. Float-point results (such as the output of `normal`) may have
|
||||
small numerical discrepancies between CPU and GPU.
|
||||
|
||||
Because of different counter-reservation schemes, TPU's integer random numbers
|
||||
will be different from CPU/GPU even with the same algorithm and seed.
|
||||
Also, TPU uses different sampling algorithms for some distributions
|
||||
(e.g. using reverse CDF for sampling normal distribution instead of
|
||||
Box-Muller used by CPU/GPU). Harmonizing TPU's RNG behavior with CPU/GPU is
|
||||
work in progress.
|
||||
CPU, GPU and TPU with the same algorithm and seed will generate the same
|
||||
integer random numbers. Float-point results (such as the output of `normal`)
|
||||
may have small numerical discrepancies between CPU and GPU.
|
||||
"""
|
||||
|
||||
def __init__(self, copy_from=None, seed=None, algorithm=None):
|
||||
def __init__(self, copy_from=None, state=None, alg=None):
|
||||
"""Creates a generator.
|
||||
|
||||
The new generator will be initialized by one of the following ways, with
|
||||
decreasing precedence:
|
||||
(1) If `copy_from` is not None, the new generator is initialized by copying
|
||||
information from another generator.
|
||||
(3) If `state` and `alg` are not None (they must be set together), the new
|
||||
generator is initialized by a state.
|
||||
|
||||
Args:
|
||||
copy_from: (optional) a generator to be copied from.
|
||||
seed: (optional) the seed for the RNG. If None, it will be chosen
|
||||
nondeterministically
|
||||
algorithm: (optional) the RNG algorithm. If None, it will be
|
||||
auto-selected.
|
||||
copy_from: a generator to be copied from.
|
||||
state: a vector of dtype STATE_TYPE representing the initial state of the
|
||||
RNG, whose length and semantics are algorithm-specific.
|
||||
alg: the RNG algorithm. Possible values are RNG_ALG_PHILOX for the
|
||||
Philox algorithm and RNG_ALG_THREEFRY for the ThreeFry
|
||||
algorithm (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
|
||||
[https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
|
||||
"""
|
||||
if copy_from is None:
|
||||
if algorithm is None:
|
||||
# TODO(wangpeng): more sophisticated algorithm selection
|
||||
algorithm = DEFAULT_ALGORITHM
|
||||
if seed is None:
|
||||
state = non_deterministic_ints(shape=[_get_state_size(algorithm)],
|
||||
dtype=SEED_TYPE)
|
||||
else:
|
||||
state = create_rng_state(seed, algorithm)
|
||||
self._state_var = variables.Variable(state,
|
||||
dtype=STATE_TYPE,
|
||||
trainable=False)
|
||||
self._alg_var = algorithm
|
||||
else:
|
||||
assert seed is None
|
||||
self._state_var = variables.Variable(copy_from.state,
|
||||
dtype=STATE_TYPE,
|
||||
if copy_from is not None:
|
||||
# All other arguments should be None
|
||||
assert (alg or state) is None
|
||||
self._state_var = variables.Variable(copy_from.state, dtype=STATE_TYPE,
|
||||
trainable=False)
|
||||
self._alg_var = copy_from.algorithm
|
||||
|
||||
def reset(self, seed):
|
||||
"""Resets the generator.
|
||||
else:
|
||||
assert alg is not None and state is not None
|
||||
state = _convert_to_state_tensor(state)
|
||||
state.shape.assert_is_compatible_with([_get_state_size(alg)])
|
||||
self._state_var = variables.Variable(state, dtype=STATE_TYPE,
|
||||
trainable=False)
|
||||
self._alg_var = alg
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state, alg):
|
||||
"""Creates a generator from a state.
|
||||
|
||||
See `__init__` for description of `state` and `alg`.
|
||||
|
||||
Args:
|
||||
seed: the seed to reset the RNG to.
|
||||
state: the new state.
|
||||
alg: the RNG algorithm.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
"""
|
||||
return cls(alg=alg, state=state)
|
||||
|
||||
@classmethod
|
||||
def from_seed(cls, seed, alg=None):
|
||||
"""Creates a generator from a seed.
|
||||
|
||||
A seed is a 1024-bit unsigned integer represented either as a Python
|
||||
integer or a vector of integers. Seeds shorter than 1024-bit will be
|
||||
padded. The padding, the internal structure of a seed and the way a seed
|
||||
is converted to a state are all opaque (unspecified). The only semantics
|
||||
specification of seeds is that two different seeds are likely to produce
|
||||
two independent generators (but no guarantee).
|
||||
|
||||
Args:
|
||||
seed: the seed for the RNG.
|
||||
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
|
||||
`__init__` for its possible values.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
"""
|
||||
if alg is None:
|
||||
# TODO(wangpeng): more sophisticated algorithm selection
|
||||
alg = DEFAULT_ALGORITHM
|
||||
state = create_rng_state(seed, alg)
|
||||
return cls(state=state, alg=alg)
|
||||
|
||||
@classmethod
|
||||
def from_non_deterministic_state(cls, alg=None):
|
||||
"""Creates a generator by non-deterministically initializing its state.
|
||||
|
||||
The source of the non-determinism will be platform- and time-dependent.
|
||||
|
||||
Args:
|
||||
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
|
||||
`__init__` for its possible values.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
"""
|
||||
if alg is None:
|
||||
# TODO(wangpeng): more sophisticated algorithm selection
|
||||
alg = DEFAULT_ALGORITHM
|
||||
state = non_deterministic_ints(shape=[_get_state_size(alg)],
|
||||
dtype=SEED_TYPE)
|
||||
return cls(state=state, alg=alg)
|
||||
|
||||
@classmethod
|
||||
def from_key_counter(cls, key, counter, alg):
|
||||
"""Creates a generator from a key and a counter.
|
||||
|
||||
This constructor only applies if the algorithm is a counter-based algorithm.
|
||||
See method `key` for the meaning of "key" and "counter".
|
||||
|
||||
Args:
|
||||
key: the key for the RNG, a scalar of type STATE_TYPE.
|
||||
counter: a vector of dtype STATE_TYPE representing the initial counter for
|
||||
the RNG, whose length is algorithm-specific.,
|
||||
alg: the RNG algorithm. If None, it will be auto-selected. See
|
||||
`__init__` for its possible values.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
"""
|
||||
counter = _convert_to_state_tensor(counter)
|
||||
key = _convert_to_state_tensor(key)
|
||||
counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
|
||||
key.shape.assert_is_compatible_with([])
|
||||
key = array_ops.reshape(key, [1])
|
||||
state = array_ops.concat([counter, key], 0)
|
||||
return cls(state=state, alg=alg)
|
||||
|
||||
def reset(self, state):
|
||||
"""Resets the generator by a new state.
|
||||
|
||||
See `__init__` for the meaning of "state".
|
||||
|
||||
Args:
|
||||
state: the new state.
|
||||
"""
|
||||
state = _convert_to_state_tensor(state)
|
||||
state.shape.assert_is_compatible_with([_get_state_size(self.algorithm)])
|
||||
self._state_var.assign(state)
|
||||
|
||||
def reset_from_seed(self, seed):
|
||||
"""Resets the generator by a new seed.
|
||||
|
||||
See `from_seed` for the meaning of "seed".
|
||||
|
||||
Args:
|
||||
seed: the new seed.
|
||||
"""
|
||||
state = create_rng_state(seed, self.algorithm)
|
||||
self._state_var.assign(state)
|
||||
|
||||
def reset_from_key_counter(self, key, counter):
|
||||
"""Resets the generator by a new key-counter pair.
|
||||
|
||||
See `from_key_counter` for the meaning of "key" and "counter".
|
||||
|
||||
Args:
|
||||
key: the new key.
|
||||
counter: the new counter.
|
||||
"""
|
||||
counter = _convert_to_state_tensor(counter)
|
||||
key = _convert_to_state_tensor(key)
|
||||
counter.shape.assert_is_compatible_with(
|
||||
[_get_state_size(self.algorithm) - 1])
|
||||
key.shape.assert_is_compatible_with([])
|
||||
key = array_ops.reshape(key, [1])
|
||||
state = array_ops.concat([counter, key], 0)
|
||||
self._state_var.assign(state)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""The internal state of the RNG."""
|
||||
@ -235,7 +360,7 @@ class Generator(tracking.AutoTrackable):
|
||||
|
||||
For a counter-base RNG algorithm such as Philox and ThreeFry (as
|
||||
described in paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
|
||||
(https://www.thesalmons.org/john/random123/papers/random123sc11.pdf)),
|
||||
[https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]),
|
||||
the RNG state consists of two parts: counter and key. The output is
|
||||
generated via the formula: output=hash(key, counter), i.e. a hashing of
|
||||
the counter parametrized by the key. Two RNGs with two different keys can
|
||||
@ -535,7 +660,7 @@ class Generator(tracking.AutoTrackable):
|
||||
alg = self.algorithm
|
||||
if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
|
||||
keys = self._make_int64_keys(shape=[count])
|
||||
return [Generator(seed=_key_to_state(alg, key), algorithm=alg)
|
||||
return [Generator(state=_key_to_state(alg, key), alg=alg)
|
||||
for key in keys.numpy()]
|
||||
else:
|
||||
raise ValueError("Unsupported algorithm id: %s" % alg)
|
||||
@ -551,31 +676,27 @@ global_generator = None
|
||||
def get_global_generator():
|
||||
global global_generator
|
||||
if global_generator is None:
|
||||
global_generator = Generator()
|
||||
global_generator = Generator.from_non_deterministic_state()
|
||||
return global_generator
|
||||
|
||||
|
||||
@tf_export("random.experimental.set_global_generator")
|
||||
def set_global_generator(generator):
|
||||
"""Replaces the global generator with another `Generator` object.
|
||||
|
||||
This function creates a new Generator object (and the Variable object within),
|
||||
which does not work well with tf.function because (1) tf.function puts
|
||||
restrictions on Variable creation thus reset_global_generator can't be freely
|
||||
used inside tf.function; (2) redirecting a global variable to
|
||||
a new object is problematic with tf.function because the old object may be
|
||||
captured by a 'tf.function'ed function and still be used by it.
|
||||
A 'tf.function'ed function only keeps weak references to variables,
|
||||
so deleting a variable and then calling that function again may raise an
|
||||
error, as demonstrated by
|
||||
random_test.py/RandomTest.testResetGlobalGeneratorBadWithDefun .
|
||||
|
||||
Args:
|
||||
generator: the new `Generator` object.
|
||||
"""
|
||||
global global_generator
|
||||
global_generator = generator
|
||||
|
||||
|
||||
# This function creates a new Generator object (and the Variable object within),
|
||||
# which does not work well with tf.function because (1) tf.function puts
|
||||
# restrictions on Variable creation thus reset_global_generator can't be freely
|
||||
# used inside tf.function; (2) redirecting a global variable to
|
||||
# a new object is problematic with tf.function because the old object may be
|
||||
# captured by a 'tf.function'ed function and still be used by it.
|
||||
# A 'tf.function'ed function only keeps weak references to variables,
|
||||
# so deleting a variable and then calling that function again may raise an
|
||||
# error, as demonstrated by
|
||||
# random_test.py/RandomTest.testResetGlobalGeneratorBadWithDefun .
|
||||
# The function 'set_global_generator' below also has this problem.
|
||||
@tf_export("random.experimental.reset_global_generator")
|
||||
def reset_global_generator(seed, algorithm=None):
|
||||
global global_generator
|
||||
if algorithm is None:
|
||||
# preserve the old algorithm
|
||||
algorithm = int(get_global_generator().algorithm)
|
||||
global_generator = Generator(seed=seed, algorithm=algorithm)
|
||||
|
@ -96,7 +96,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"""
|
||||
shape = [2, 3]
|
||||
count = 6
|
||||
gen = random.Generator(seed=1234)
|
||||
gen = random.Generator.from_seed(1234)
|
||||
keys1 = gen._make_int64_keys(shape=shape)
|
||||
keys2 = gen._make_int64_keys(shape=shape)
|
||||
self.assertAllDifferent([keys1, keys2])
|
||||
@ -126,42 +126,100 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"""Tests that a CPU RNG can split into RNGs on GPU.
|
||||
"""
|
||||
with ops.device("/device:CPU:0"):
|
||||
gen = random.Generator(seed=1234) # gen is on CPU
|
||||
gen = random.Generator.from_seed(1234) # gen is on CPU
|
||||
self.assertRegex("CPU", gen.state.device)
|
||||
with ops.device(test_util.gpu_device_name()):
|
||||
gens = gen.split(count=10) # gens are on GPU
|
||||
self.assertRegex("GPU", gens[0].state.device)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testGeneratorCreationInDefun(self):
|
||||
"""Tests creating a Generator in defun.
|
||||
def testReset(self):
|
||||
shape = [2, 3]
|
||||
gen = random.Generator.from_seed(0)
|
||||
for resetter in [
|
||||
lambda g: g.reset(state=[1, 2, 3]),
|
||||
lambda g: g.reset_from_seed(1234),
|
||||
lambda g: g.reset_from_key_counter(key=1, counter=[2, 3]),
|
||||
]:
|
||||
resetter(gen)
|
||||
expected_normal = gen.normal(shape)
|
||||
@def_function.function
|
||||
def f(resetter):
|
||||
resetter(gen)
|
||||
return gen.normal(shape)
|
||||
def check_results(expected_normal, v):
|
||||
self.assertAllEqual(expected_normal, v)
|
||||
check_results(expected_normal, f(resetter))
|
||||
check_results(expected_normal, f(resetter))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testGeneratorCreation(self):
|
||||
"""Tests generator creation, in both eager and tf.function.
|
||||
|
||||
The interaction between Generator creation and defun should be the same as
|
||||
tf.Variable.
|
||||
"""
|
||||
seed = 1234
|
||||
shape = [2, 3]
|
||||
with ops.device("/device:CPU:0"):
|
||||
gen = random.Generator(seed=seed)
|
||||
alg = random.RNG_ALG_PHILOX
|
||||
for constructor in [
|
||||
lambda: random.Generator(state=[1, 2, 3], alg=alg),
|
||||
lambda: random.Generator.from_seed(1234),
|
||||
lambda: random.Generator.from_key_counter( # pylint: disable=g-long-lambda
|
||||
key=1, counter=[2, 3], alg=alg),
|
||||
]:
|
||||
gen = constructor()
|
||||
# Tests tf.function
|
||||
expected_normal1 = gen.normal(shape)
|
||||
expected_normal2 = gen.normal(shape)
|
||||
global g_seeded
|
||||
g_seeded = None
|
||||
@def_function.function
|
||||
def f():
|
||||
def f(constructor):
|
||||
global g_seeded
|
||||
global g_unseeded
|
||||
# defun'ed function should only create variables once
|
||||
if g_seeded is None:
|
||||
g_seeded = random.Generator(seed=seed)
|
||||
if g_unseeded is None:
|
||||
g_unseeded = random.Generator()
|
||||
r = g_seeded.normal(shape)
|
||||
r = (r, g_unseeded.normal(shape))
|
||||
return r
|
||||
def check_results(expected_normal, v1, v2):
|
||||
self.assertAllEqual(expected_normal, v1)
|
||||
self.assertAllEqual(shape, v2.shape)
|
||||
check_results(expected_normal1, *f())
|
||||
check_results(expected_normal2, *f())
|
||||
g_seeded = constructor()
|
||||
return g_seeded.normal(shape)
|
||||
def check_results(expected_normal, v):
|
||||
self.assertAllEqual(expected_normal, v)
|
||||
check_results(expected_normal1, f(constructor))
|
||||
check_results(expected_normal2, f(constructor))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testGeneratorCreationUnseeded(self):
|
||||
"""Tests generator creation, the unseeded case."""
|
||||
shape = [2, 3]
|
||||
global g_unseeded
|
||||
g_unseeded = None
|
||||
@def_function.function
|
||||
def f():
|
||||
global g_unseeded
|
||||
# defun'ed function should only create variables once
|
||||
if g_unseeded is None:
|
||||
g_unseeded = random.Generator.from_non_deterministic_state()
|
||||
return g_unseeded.normal(shape)
|
||||
self.assertAllEqual(shape, f().shape)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testGeneratorCopy(self):
|
||||
"""Tests copying a generator."""
|
||||
g = random.Generator.from_seed(0)
|
||||
g_copy = random.Generator(g)
|
||||
self.assertAllEqual(g.algorithm, g_copy.algorithm)
|
||||
self.assertAllEqual(g.state.read_value(), g_copy.state.read_value())
|
||||
# Tests tf.function
|
||||
global g_seeded
|
||||
g_seeded = None
|
||||
# Do the same in tf.function
|
||||
@def_function.function
|
||||
def f():
|
||||
global g_seeded
|
||||
# defun'ed function should only create variables once
|
||||
if g_seeded is None:
|
||||
g_seeded = random.Generator(g)
|
||||
self.assertAllEqual(g.algorithm, g_seeded.algorithm)
|
||||
self.assertAllEqual(g.state.read_value(), g_seeded.state.read_value())
|
||||
f()
|
||||
|
||||
@test_util.run_v1_only(
|
||||
("This test is specifically for checking TF1 compatibility. "
|
||||
@ -176,8 +234,8 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
[[-0.3964749, 0.8369565, -0.30946946],
|
||||
[1.1206646, 1.00852597, -0.10185789]], dtype=dtypes.float32)
|
||||
with self.cached_session() as sess:
|
||||
gen1 = random.Generator(seed=seed)
|
||||
gen2 = random.Generator()
|
||||
gen1 = random.Generator.from_seed(seed)
|
||||
gen2 = random.Generator.from_non_deterministic_state()
|
||||
sess.run((gen1._state_var.initializer, gen2._state_var.initializer))
|
||||
r1 = gen1.normal(shape, dtype=dtypes.float32)
|
||||
r2 = gen2.normal(shape, dtype=dtypes.float32)
|
||||
@ -203,9 +261,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
https://github.com/tensorflow/tensorflow/issues/9171
|
||||
"""
|
||||
shape = (3,)
|
||||
random.get_global_generator().reset(1)
|
||||
random.get_global_generator().reset_from_seed(1)
|
||||
a = random.get_global_generator().normal(shape)
|
||||
random.get_global_generator().reset(1)
|
||||
random.get_global_generator().reset_from_seed(1)
|
||||
b = random.get_global_generator().normal(shape)
|
||||
self.assertAllEqual(a, b)
|
||||
|
||||
@ -214,9 +272,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
def f():
|
||||
return random.get_global_generator().normal(shape)
|
||||
|
||||
random.get_global_generator().reset(1)
|
||||
random.get_global_generator().reset_from_seed(1)
|
||||
c = f()
|
||||
random.get_global_generator().reset(1)
|
||||
random.get_global_generator().reset_from_seed(1)
|
||||
d = f()
|
||||
self.assertAllEqual(c, d)
|
||||
self.assertAllEqual(a, c)
|
||||
@ -237,9 +295,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
return random.get_global_generator().normal(shape)
|
||||
|
||||
def compare(fst_includes_print, snd_includes_print):
|
||||
random.get_global_generator().reset(50)
|
||||
random.get_global_generator().reset_from_seed(50)
|
||||
fst = f(fst_includes_print)
|
||||
random.get_global_generator().reset(50)
|
||||
random.get_global_generator().reset_from_seed(50)
|
||||
snd = f(snd_includes_print)
|
||||
self.assertAllEqual(fst, snd)
|
||||
# Now do the above again using accelerated (defunned) 'f'.
|
||||
@ -247,9 +305,9 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
# two different graphs to be generated, hence demonstrating the
|
||||
# insensitivity to graph changes.
|
||||
f_acc = def_function.function(f)
|
||||
random.get_global_generator().reset(50)
|
||||
random.get_global_generator().reset_from_seed(50)
|
||||
fst = f_acc(fst_includes_print)
|
||||
random.get_global_generator().reset(50)
|
||||
random.get_global_generator().reset_from_seed(50)
|
||||
snd = f_acc(snd_includes_print)
|
||||
self.assertAllEqual(fst, snd)
|
||||
|
||||
@ -260,7 +318,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
@test_util.run_v2_only
|
||||
def testKey(self):
|
||||
key = 1234
|
||||
gen = random.Generator(seed=[0, 0, key])
|
||||
gen = random.Generator(state=[0, 0, key], alg=random.RNG_ALG_PHILOX)
|
||||
got = gen.key
|
||||
self.assertAllEqual(key, got)
|
||||
@def_function.function
|
||||
@ -273,7 +331,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
def testSkip(self):
|
||||
key = 1234
|
||||
counter = 5678
|
||||
gen = random.Generator(seed=[counter, 0, key])
|
||||
gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
|
||||
delta = 432
|
||||
gen.skip(delta)
|
||||
new_counter = gen._state_var[0]
|
||||
@ -285,7 +343,8 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
# note how the two seeds for the old op correspond to the seed for the new
|
||||
# op
|
||||
with ops.device(device):
|
||||
gen = random.Generator(seed=[0, seed2, seed1])
|
||||
gen = random.Generator(state=[0, seed2, seed1],
|
||||
alg=random.RNG_ALG_PHILOX)
|
||||
|
||||
# create a graph for the old op in order to call it many times
|
||||
@def_function.function
|
||||
@ -361,10 +420,10 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
seed = 1234
|
||||
shape = [315, 49]
|
||||
with ops.device("/device:CPU:0"):
|
||||
cpu = random.Generator(seed=seed).uniform_full_int(
|
||||
cpu = random.Generator.from_seed(seed).uniform_full_int(
|
||||
shape=shape, dtype=dtype)
|
||||
with ops.device(test_util.gpu_device_name()):
|
||||
gpu = random.Generator(seed=seed).uniform_full_int(
|
||||
gpu = random.Generator.from_seed(seed).uniform_full_int(
|
||||
shape=shape, dtype=dtype)
|
||||
self.assertAllEqual(cpu, gpu)
|
||||
|
||||
@ -374,7 +433,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
minval = 2
|
||||
maxval = 33
|
||||
size = 1000
|
||||
gen = random.Generator(seed=1234)
|
||||
gen = random.Generator.from_seed(1234)
|
||||
x = gen.uniform(
|
||||
shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
|
||||
self.assertTrue(np.all(x >= minval))
|
||||
@ -383,7 +442,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
@parameterized.parameters(FLOATS)
|
||||
@test_util.run_v2_only
|
||||
def testNormalIsFinite(self, dtype):
|
||||
gen = random.Generator(seed=1234)
|
||||
gen = random.Generator.from_seed(1234)
|
||||
x = gen.normal(shape=[10000], dtype=dtype).numpy()
|
||||
self.assertTrue(np.all(np.isfinite(x)))
|
||||
|
||||
@ -393,7 +452,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"""Use Pearson's Chi-squared test to test for uniformity."""
|
||||
n = 1000
|
||||
seed = 12
|
||||
gen = random.Generator(seed=seed)
|
||||
gen = random.Generator.from_seed(seed)
|
||||
maxval = 1
|
||||
if dtype.is_integer:
|
||||
maxval = 100
|
||||
@ -413,7 +472,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
def testDistributionOfNormal(self, dtype):
|
||||
"""Use Anderson-Darling test to test distribution appears normal."""
|
||||
n = 1000
|
||||
gen = random.Generator(seed=1234)
|
||||
gen = random.Generator.from_seed(1234)
|
||||
x = gen.normal(shape=[n], dtype=dtype).numpy()
|
||||
# The constant 2.492 is the 5% critical value for the Anderson-Darling
|
||||
# test where the mean and variance are known. This test is probabilistic
|
||||
@ -426,7 +485,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"""Tests that proper errors are raised.
|
||||
"""
|
||||
shape = [2, 3]
|
||||
gen = random.Generator(seed=1234)
|
||||
gen = random.Generator.from_seed(1234)
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
errors.InvalidArgumentError,
|
||||
r"must have shape \[\], not"):
|
||||
@ -466,8 +525,8 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
var.handle, random.RNG_ALG_PHILOX, shape)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testResetGlobalGeneratorBadWithDefun(self):
|
||||
"""Demonstrates that reset_global_generator don't work properly with defun.
|
||||
def testSetGlobalGeneratorBadWithDefun(self):
|
||||
"""Demonstrates that set_global_generator don't work properly with defun.
|
||||
"""
|
||||
shape = (3,)
|
||||
|
||||
@ -475,11 +534,11 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
def f():
|
||||
return random.get_global_generator().normal(shape)
|
||||
|
||||
random.reset_global_generator(50)
|
||||
random.set_global_generator(random.Generator.from_seed(50))
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
errors.NotFoundError, "Resource .+ does not exist"):
|
||||
_ = f()
|
||||
random.reset_global_generator(50)
|
||||
random.set_global_generator(random.Generator.from_seed(50))
|
||||
_ = f()
|
||||
|
||||
@test_util.run_v2_only
|
||||
@ -492,7 +551,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"""
|
||||
shape = [3, 4]
|
||||
dtype = dtypes.int32
|
||||
gen = random.Generator(seed=1234)
|
||||
gen = random.Generator.from_seed(1234)
|
||||
strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
|
||||
with strat.scope():
|
||||
def f():
|
||||
@ -519,7 +578,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
dtype = dtypes.int32
|
||||
strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
|
||||
with strat.scope():
|
||||
gen = random.Generator(seed=1234)
|
||||
gen = random.Generator.from_seed(1234)
|
||||
def f():
|
||||
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
@ -545,7 +604,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
dtype = dtypes.int32
|
||||
strat = MirroredStrategy(devices=["/cpu:0", test_util.gpu_device_name()])
|
||||
def f():
|
||||
gen = random.Generator(seed=1234)
|
||||
gen = random.Generator.from_seed(1234)
|
||||
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
t = array_ops.stack([t1, t2])
|
||||
@ -575,7 +634,7 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
# different random-number stream. The only obstacle is that op
|
||||
# 'NonDeterministicInts' is not implemented on GPU.)
|
||||
with strat.scope():
|
||||
gen = random.Generator()
|
||||
gen = random.Generator.from_non_deterministic_state()
|
||||
def f():
|
||||
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
|
@ -18,12 +18,28 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'copy_from\', \'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'copy_from\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "binomial"
|
||||
argspec: "args=[\'self\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_key_counter"
|
||||
argspec: "args=[\'cls\', \'key\', \'counter\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_non_deterministic_state"
|
||||
argspec: "args=[\'cls\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_seed"
|
||||
argspec: "args=[\'cls\', \'seed\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_state"
|
||||
argspec: "args=[\'cls\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_seeds"
|
||||
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
|
||||
@ -34,6 +50,14 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "reset"
|
||||
argspec: "args=[\'self\', \'state\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reset_from_key_counter"
|
||||
argspec: "args=[\'self\', \'key\', \'counter\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reset_from_seed"
|
||||
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
|
@ -12,10 +12,6 @@ tf_module {
|
||||
name: "get_global_generator"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reset_global_generator"
|
||||
argspec: "args=[\'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_global_generator"
|
||||
argspec: "args=[\'generator\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -18,12 +18,28 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'copy_from\', \'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'copy_from\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "binomial"
|
||||
argspec: "args=[\'self\', \'shape\', \'counts\', \'probs\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_key_counter"
|
||||
argspec: "args=[\'cls\', \'key\', \'counter\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_non_deterministic_state"
|
||||
argspec: "args=[\'cls\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_seed"
|
||||
argspec: "args=[\'cls\', \'seed\', \'alg\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_state"
|
||||
argspec: "args=[\'cls\', \'state\', \'alg\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_seeds"
|
||||
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'1\'], "
|
||||
@ -34,6 +50,14 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "reset"
|
||||
argspec: "args=[\'self\', \'state\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reset_from_key_counter"
|
||||
argspec: "args=[\'self\', \'key\', \'counter\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reset_from_seed"
|
||||
argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
|
@ -12,10 +12,6 @@ tf_module {
|
||||
name: "get_global_generator"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reset_global_generator"
|
||||
argspec: "args=[\'seed\', \'algorithm\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "set_global_generator"
|
||||
argspec: "args=[\'generator\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
Reference in New Issue
Block a user