Allows creating tf.random.Generator under distribution-strategy scopes. Different replicas will get different random-number streams.

All strategies are supported except for CentralStorageStrategy and ParameterServerStrategy.
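
For example (a minimal sketch of the new behavior, assuming a 2-replica MirroredStrategy on two CPU devices; this mirrors the example added to the `Generator` docstring below):

```
import tensorflow as tf

strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
  g = tf.random.Generator.from_seed(1)  # previously raised ValueError; now allowed

  def f():
    return g.normal([])

  # Each replica draws from its own stream, so results[0] != results[1],
  # but both are still deterministic given the seed.
  results = strat.run(f).values
```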

This CL also removes the CompositeTensor superclass from Generator. Generator is a wrapper around tf.Variable, and since tf.Variable is not a CompositeTensor, Generator cannot properly be one either. Previously we made it a CompositeTensor by exposing Variable.handle as its sole component, but that breaks down when the variable is a DistributedVariable, which has no single handle in cross-replica context.
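
A sketch of where the old decomposition breaks (illustrative only; it relies on a mirrored variable's `.handle` raising ValueError in cross-replica context, the same behavior the removed Keras `_RandomGenerator.state` workaround below catches):

```
import tensorflow as tf

strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
with strat.scope():
  v = tf.Variable([1, 2, 3])  # becomes a MirroredVariable, one copy per replica

# In cross-replica context there is no single backing resource tensor, so a
# CompositeTensor decomposition built on `state.handle` has nothing
# well-defined to return.
try:
  _ = v.handle
except ValueError:
  print("no unique handle outside a replica context")
```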

PiperOrigin-RevId: 350851648
Change-Id: I5f4d77ddb990557fcc9c7336987203ecdaec5b9a
Peng Wang, 2021-01-08 15:52:18 -08:00 (committed by TensorFlower Gardener)
parent 70502be4a5
commit 587ac71f68
12 changed files with 495 additions and 294 deletions


@@ -24,6 +24,7 @@
   `_tpu_estimator_embedding.py`. This allows embedding lookup statistics
   gathered at runtime to be used in embedding layer partitioning decisions.
 * `tf.keras.metrics.AUC` now supports logit predictions.
+* Creating `tf.random.Generator` under `tf.distribute.Strategy` scopes is now allowed (except for `tf.distribute.experimental.CentralStorageStrategy` and `tf.distribute.experimental.ParameterServerStrategy`). Different replicas will get different random-number streams.

 ## Bug Fixes and Other Changes


@@ -587,6 +587,22 @@ py_library(
     ],
 )

+distribute_py_test(
+    name = "random_generator_test",
+    srcs = ["random_generator_test.py"],
+    main = "random_generator_test.py",
+    shard_count = 10,
+    tags = [
+        "multi_and_single_gpu",
+    ],
+    deps = [
+        "//tensorflow/python:stateful_random_ops",
+        "//tensorflow/python/compat:v2_compat",
+        "//tensorflow/python/distribute:combinations",
+        "//tensorflow/python/distribute:strategy_combinations",
+    ],
+)
+
 tpu_py_test(
     name = "tpu_strategy_test",
     srcs = ["tpu_strategy_test.py"],


@ -0,0 +1,255 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests tf.random.Generator with distribution strategies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

from absl.testing import parameterized

from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import stateful_random_ops as rng
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.training.tracking import util as tracking_util


def get_num_local_replicas(strat, values=None):
  strat_name = type(strat).__name__
  if "MultiWorker" in strat_name or "CollectiveAllReduceStrategy" in strat_name:
    if values is None:
      values = strat.run(lambda: constant_op.constant(0))
    values = strat.experimental_local_results(values)
    return len(values)
  else:
    return strat.num_replicas_in_sync


all_strategies = (strategy_combinations.all_strategies +
                  strategy_combinations.multiworker_strategies)


class GeneratorTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(GeneratorTest, self).setUp()
    v2_compat.enable_v2_behavior()
    config.set_soft_device_placement(False)

  def assertAllDifferent(self, tensors):
    """Checks that there are no duplicate elements anywhere among the tensors.

    Args:
      tensors: a list of tensors. They can have different shapes.
    """
    values = [array_ops.reshape(t, shape=[-1]) for t in tensors]
    values = array_ops.concat(values, axis=0)
    values = self.evaluate(values)
    values = values.tolist()
    self.assertAllEqual(len(values), len(set(values)))

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"]))
  def testCrossReplica(self, strat):
    """Tests that RNG can be properly advanced in cross-replica context."""
    strat_name = type(strat).__name__
    if "CentralStorage" in strat_name:
      self.skipTest("Does not work with CentralStorageStrategy yet.")

    def read_values(dv):
      return [v.read_value() for v in strat.experimental_local_results(dv)]

    with strat.scope():
      g = rng.Generator.from_seed(1)
      s1 = read_values(g.state)
      g.normal([3])
      g.skip(4)
      s2 = read_values(g.state)
    self.assertNotAllEqual(s1[0], s2[0])
    self.assertEqual(len(s1), len(s2))
    for i in range(1, len(s1)):
      self.assertAllEqual(s1[0], s1[i])
      self.assertAllEqual(s2[0], s2[i])

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"],
          seeded=[True, False]))
  def testDistStrat(self, strat, seeded):
    """Tests RNG with distribution strategies."""
    strat_name = type(strat).__name__
    if "CentralStorage" in strat_name:
      self.skipTest("Does not work with CentralStorageStrategy yet.")
    creators = {
        True: functools.partial(rng.Generator.from_seed, 1234),
        False: rng.Generator.from_non_deterministic_state,
    }
    shape = [3, 4]
    dtype = dtypes.int32
    creator = creators[seeded]
    with strat.scope():
      gen = creator()

      @def_function.function
      def f():
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      results = strat.run(f)
      values = strat.experimental_local_results(results)
      n = get_num_local_replicas(strat, values)
      self.assertAllEqual(n, len(values))
      self.assertAllDifferent(values)

  @ds_combinations.generate(
      combinations.combine(
          strat=all_strategies,
          mode=["eager"]))
  def testDistVarAsTFFunArg(self, strat):
    """Tests that RNG with dist variables can be used as tf.function's arg."""
    strat_name = type(strat).__name__
    if "CentralStorage" in strat_name:
      self.skipTest("Does not work with CentralStorageStrategy yet.")
    shape = [3, 4]
    dtype = dtypes.int32
    with strat.scope():
      gen = rng.Generator.from_seed(1234)

      @def_function.function
      def f(gen):  # the main focus
        t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
        t = array_ops.stack([t1, t2])
        return t

      @def_function.function  # required by TPUStrategy.run
      def g():
        return f(gen)

      for _ in range(2):
        results = strat.run(g)
        values = strat.experimental_local_results(results)
        n = get_num_local_replicas(strat, values)
        self.assertAllEqual(n, len(values))
        self.assertAllDifferent(values)

  @ds_combinations.generate(
      combinations.combine(
          strat1=strategy_combinations.all_strategies,
          strat2=strategy_combinations.all_strategies,
          mode=["eager"]) +
      combinations.combine(
          strat1=strategy_combinations.multiworker_strategies,
          strat2=[None],
          mode=["eager"]))
  def testDistStratRestore(self, strat1, strat2):
    """Tests checkpointing and restoring (to possibly different #replicas)."""
    if strat2 is None:
      strat2 = strat1
    strat1_name = type(strat1).__name__
    strat2_name = type(strat2).__name__
    if "CentralStorage" in strat1_name or "CentralStorage" in strat2_name:
      self.skipTest("Does not work with CentralStorageStrategy yet.")
    if "Default" in strat1_name or "Default" in strat2_name:
      self.skipTest(
          "We don't guarantee consistency between strategy and no-strategy.")
    fname = os.path.join(self.get_temp_dir(), "checkpoint")

    def uniform(strat, g):

      @def_function.function
      def f():
        return g.uniform_full_int([3], dtype=dtypes.int32)

      result = strat.run(f)
      return strat.experimental_local_results(result)

    with strat1.scope():
      g1 = rng.Generator.from_seed(1)
    with strat2.scope():
      g2 = rng.Generator.from_seed(10)
    cp1 = tracking_util.Checkpoint(g=g1)
    cp2 = tracking_util.Checkpoint(g=g2)

    def write_restore_compare():
      cp1.write(fname)
      r1 = uniform(strat1, g1)
      cp2.restore(fname)
      r2 = uniform(strat2, g2)
      # Tests that overlapping replicas are properly restored.
      n1 = get_num_local_replicas(strat1)
      n2 = get_num_local_replicas(strat2)
      n = min(n1, n2)
      self.assertAllEqual(r1[:n], r2[:n])

    # Run multiple times so that cp1.write is called in various RNG states.
    for _ in range(2):
      write_restore_compare()

  @ds_combinations.generate(
      combinations.combine(
          strat=strategy_combinations.all_strategies,
          mode=["eager"],
          is_save_in_scope=[True, False]))
  def testSavedModel(self, strat, is_save_in_scope):
    strat_name = type(strat).__name__
    if "CentralStorage" in strat_name:
      self.skipTest("Does not work with CentralStorageStrategy yet.")

    class CustomModule(module.Module):

      def __init__(self):
        super(CustomModule, self).__init__()
        self.g = rng.Generator.from_seed(0)

      @def_function.function
      def __call__(self):
        return self.g.state

      @def_function.function
      def mutate(self):
        self.g.normal([])

    with strat.scope():
      m = CustomModule()
      m.mutate()
      state_before = m()
      path = os.path.join(self.get_temp_dir(), "saved_model")

    if is_save_in_scope:
      with strat.scope():
        save.save(m, path)
    else:
      save.save(m, path)
    with strat.scope():
      m.mutate()
      state_before_2 = m()

    imported = load.load(path)
    state_after = imported()
    self.assertAllEqual(state_before, state_after)
    imported.mutate()
    state_after_2 = imported()
    self.assertAllEqual(state_before_2, state_after_2)


if __name__ == "__main__":
  multi_process_runner.test_main()


@@ -39,7 +39,6 @@ from tensorflow.python.ops import image_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import stateful_random_ops
 from tensorflow.python.ops import stateless_random_ops
-from tensorflow.python.ops import variables
 from tensorflow.python.util.tf_export import keras_export

 ResizeMethod = image_ops.ResizeMethod
@@ -1356,44 +1355,20 @@ class RandomWidth(PreprocessingLayer):
     return dict(list(base_config.items()) + list(config.items()))


-# TODO(b/147877541, b/158339556): This class is added to temporarily enable
-# creating generators within distribution strategies. Remove it when the proper
-# API is in place.
-class _RandomGenerator(stateful_random_ops.Generator):
-  """A subclass that allows creation inside distribution strategies.
-
-  This is a temporary solution to allow creating tf.random.Generator inside
-  distribution strategies. It will be removed when proper API is in place.
-
-  All replicas will have the same RNG state and generate the same random
-  numbers.
-  """
-
-  # TODO(b/157995497): Temporarily use primary variable handle inside cross
-  # replica context.
-  @property
-  def state(self):
-    """The internal state of the RNG."""
-    state_var = self._state_var
-    try:
-      _ = getattr(state_var, 'handle')
-      return state_var
-    except ValueError:
-      return state_var.values[0]
-
-  def _create_variable(self, *args, **kwargs):
-    # This function does the same thing as the base class's namesake, except
-    # that it skips the distribution-strategy check. When we are inside a
-    # distribution-strategy scope, variables.Variable will pick a proper
-    # variable class (e.g. MirroredVariable).
-    return variables.Variable(*args, **kwargs)
-
-
 def make_generator(seed=None):
   """Creates a random generator.

   Args:
     seed: the seed to initialize the generator. If None, the generator will be
       initialized non-deterministically.

   Returns:
     A generator object.
   """
   if seed:
-    return _RandomGenerator.from_seed(seed)
+    return stateful_random_ops.Generator.from_seed(seed)
   else:
-    return _RandomGenerator.from_non_deterministic_state()
+    return stateful_random_ops.Generator.from_non_deterministic_state()


 def get_interpolation(interpolation):


@@ -41,6 +41,8 @@ class ImagePreprocessingDistributionTest(
     preprocessing_test_utils.PreprocessingLayerTest):

   def test_distribution(self, distribution):
+    if "CentralStorage" in type(distribution).__name__:
+      self.skipTest("Does not work with CentralStorageStrategy yet.")
     # TODO(b/159738418): large image input causes OOM in ubuntu multi gpu.
     np_images = np.random.random((32, 32, 32, 3)).astype(np.float32)
     image_dataset = dataset_ops.Dataset.from_tensor_slices(np_images).batch(


@@ -1082,8 +1082,6 @@ class RandomRotationTest(keras_parameterized.TestCase):
   def test_distribution_strategy(self):
     """Tests that RandomRotation can be created within distribution strategies.
-
-    And that replicas got the same random result.
     """
     input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
     with testing_utils.use_gpu():

@@ -1093,7 +1091,6 @@ class RandomRotationTest(keras_parameterized.TestCase):
       output = strat.run(lambda: layer(input_images, training=True))
     values = output.values
     self.assertAllEqual(2, len(values))
-    self.assertAllClose(values[0], values[1], rtol=1e-5)

   @testing_utils.run_v2_only
   def test_config_with_custom_name(self):


@@ -25,17 +25,14 @@ import six

 from tensorflow.python.compat import compat
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import values_util
 from tensorflow.python.eager import context
-from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_spec
-from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_stateful_random_ops
 from tensorflow.python.ops import gen_stateless_random_ops_v2
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training.tracking import tracking
 from tensorflow.python.util.tf_export import tf_export
@@ -235,40 +232,15 @@ def _convert_to_state_tensor(t):
   return ops.convert_to_tensor(t, dtype=STATE_TYPE)


-class GeneratorSpec(type_spec.TypeSpec):
-  """TypeSpec for Generator."""
-
-  def __init__(self, shape=None, dtype=None, alg=None):
-    self.shape = shape
-    self.dtype = dtype
-    self.alg = alg
-
-  @property
-  def _component_specs(self):
-    return (tensor_spec.TensorSpec(shape=(), dtype=dtypes.resource),)
-
-  def _to_components(self, value):
-    return (value.state.handle,)
-
-  def _from_components(self, components):
-    assert isinstance(components, (list, tuple))
-    assert len(components) == 1
-    handle = components[0]
-    state_var = resource_variable_ops.BaseResourceVariable(
-        handle=handle, shape=self.shape, dtype=self.dtype,
-        trainable=False, handle_deleter=object(), handle_name="RNGVar")
-    return Generator(state=state_var, alg=self.alg)
-
-  @property
-  def value_type(self):
-    return Generator
-
-  def _serialize(self):
-    return (self.shape, self.dtype, self.alg)
+def get_replica_id():
+  rctx = ds_context.get_replica_context()
+  if rctx is None:
+    return None
+  return rctx.replica_id_in_sync_group


 @tf_export("random.Generator", "random.experimental.Generator")
-class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
+class Generator(tracking.AutoTrackable):
   """Random-number generator.

   Example:
@@ -320,8 +292,137 @@
   >>> g = tf.random.get_global_generator()
   >>> g.normal(shape=(2, 3))
   <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
+
+  When creating a generator inside a `tf.distribute.Strategy` scope, each
+  replica will get a different stream of random numbers.
+
+  Note: `tf.distribute.experimental.CentralStorageStrategy` and
+  `tf.distribute.experimental.ParameterServerStrategy` are not supported yet.
+
+  For example, in this code:
+
+  ```
+  strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
+  with strat.scope():
+    g = tf.random.Generator.from_seed(1)
+    def f():
+      return g.normal([])
+    results = strat.run(f).values
+  ```
+
+  `results[0]` and `results[1]` will have different values.
+
+  If the generator is seeded (e.g. created via `Generator.from_seed`), the
+  random numbers will be determined by the seed, even though different
+  replicas get different numbers. One can think of a random number generated
+  on a replica as a hash of the replica ID and a "master" random number that
+  may be common to all replicas. Hence, the whole system is still
+  deterministic.
+
+  (Note that the random numbers on different replicas are not correlated,
+  even if they are deterministically determined by the same seed. They are
+  not correlated in the sense that no matter what statistics one calculates
+  on them, there won't be any discernible correlation.)
+
+  Generators can be freely saved and restored using `tf.train.Checkpoint`.
+  The checkpoint can be restored in a distribution strategy with a different
+  number of replicas than the original strategy. If a replica ID is present
+  in both the original and the new distribution strategy, its state will be
+  properly restored (i.e. the random-number stream from the restored point
+  will be the same as that from the saving point) unless the replicas have
+  already diverged in their RNG call traces before saving (e.g. one replica
+  has made one RNG call while another has made two RNG calls). We don't have
+  such a guarantee if the generator is saved in a strategy scope and restored
+  outside of any strategy scope, or vice versa.
   """
+
+  @classmethod
+  def from_state(cls, state, alg):
+    """Creates a generator from a state.
+
+    See `__init__` for description of `state` and `alg`.
+
+    Args:
+      state: the new state.
+      alg: the RNG algorithm.
+
+    Returns:
+      The new generator.
+    """
+    return cls(alg=alg, state=state)
+
+  @classmethod
+  def from_seed(cls, seed, alg=None):
+    """Creates a generator from a seed.
+
+    A seed is a 1024-bit unsigned integer represented either as a Python
+    integer or a vector of integers. Seeds shorter than 1024-bit will be
+    padded. The padding, the internal structure of a seed and the way a seed
+    is converted to a state are all opaque (unspecified). The only semantics
+    specification of seeds is that two different seeds are likely to produce
+    two independent generators (but no guarantee).
+
+    Args:
+      seed: the seed for the RNG.
+      alg: (optional) the RNG algorithm. If None, it will be auto-selected.
+        See `__init__` for its possible values.
+
+    Returns:
+      The new generator.
+    """
+    if alg is None:
+      # TODO(b/170668986): more sophisticated algorithm selection
+      alg = DEFAULT_ALGORITHM
+    alg = _convert_alg_to_int(alg)
+    state = create_rng_state(seed, alg)
+    return cls(state=state, alg=alg)
+
+  @classmethod
+  def from_non_deterministic_state(cls, alg=None):
+    """Creates a generator by non-deterministically initializing its state.
+
+    The source of the non-determinism will be platform- and time-dependent.
+
+    Args:
+      alg: (optional) the RNG algorithm. If None, it will be auto-selected.
+        See `__init__` for its possible values.
+
+    Returns:
+      The new generator.
+    """
+    if alg is None:
+      # TODO(b/170668986): more sophisticated algorithm selection
+      alg = DEFAULT_ALGORITHM
+    alg = _convert_alg_to_int(alg)
+    state = non_deterministic_ints(shape=[_get_state_size(alg)],
+                                   dtype=SEED_TYPE)
+    return cls(state=state, alg=alg)
+
+  @classmethod
+  def from_key_counter(cls, key, counter, alg):
+    """Creates a generator from a key and a counter.
+
+    This constructor only applies if the algorithm is a counter-based
+    algorithm. See method `key` for the meaning of "key" and "counter".
+
+    Args:
+      key: the key for the RNG, a scalar of type STATE_TYPE.
+      counter: a vector of dtype STATE_TYPE representing the initial counter
+        for the RNG, whose length is algorithm-specific.
+      alg: the RNG algorithm. If None, it will be auto-selected. See
+        `__init__` for its possible values.
+
+    Returns:
+      The new generator.
+    """
+    counter = _convert_to_state_tensor(counter)
+    key = _convert_to_state_tensor(key)
+    alg = _convert_alg_to_int(alg)
+    counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
+    key.shape.assert_is_compatible_with([])
+    key = array_ops.reshape(key, [1])
+    state = array_ops.concat([counter, key], 0)
+    return cls(state=state, alg=alg)
+
   def __init__(self, copy_from=None, state=None, alg=None):
     """Creates a generator.
@@ -346,24 +447,26 @@
         The string names `"philox"` and `"threefry"` can also be used.
         Note `PHILOX` guarantees the same numbers are produced (given
         the same random state) across all architectures (CPU, GPU, XLA etc).
-
-    Throws:
-      ValueError: if the generator is created inside a synchronous
-        `tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
-        because there is ambiguity on how to replicate a generator (e.g. should
-        it be copied so such each replica will get the same random numbers, or
-        should it be "split" into different generators that generate
-        different random numbers).
     """
+    # TODO(b/175072242): Remove distribution-strategy dependencies in this file.
+    if ds_context.has_strategy():
+      self._distribution_strategy = ds_context.get_strategy()
+    else:
+      self._distribution_strategy = None
     if copy_from is not None:
       # All other arguments should be None
       assert (alg or state) is None
       self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
                                               trainable=False)
       self._alg = copy_from.algorithm
     else:
       assert alg is not None and state is not None
+      if ds_context.has_strategy():
+        strat_name = type(ds_context.get_strategy()).__name__
+        # TODO(b/174610856): Support CentralStorageStrategy and
+        # ParameterServerStrategy.
+        if "CentralStorage" in strat_name or "ParameterServer" in strat_name:
+          raise ValueError("%s is not supported yet" % strat_name)
       alg = _convert_alg_to_int(alg)
       if isinstance(state, variables.Variable):
         _check_state_shape(state.shape, alg)
@@ -376,7 +479,7 @@
     self._alg = alg

   def _create_variable(self, *args, **kwargs):
-    """Creates a variable, and check that it's not MirroredVariable.
+    """Creates a variable.

     Args:
       *args: positional arguments passed along to `variables.Variable`.
@@ -385,135 +488,7 @@
     Returns:
       The created variable.
     """
-    if ds_context.has_strategy():
-      raise ValueError(
-          "Creating a generator within a strategy scope is disallowed, because "
-          "there is ambiguity on how to replicate a generator (e.g. should it "
-          "be copied so that each replica gets the same random numbers, or "
-          "'split' so that each replica gets different random numbers).")
-      # TODO(wangpeng): Link to the RNG guide for solutions in such cases.
-    var = variables.Variable(*args, **kwargs)
-    return var
-
-  @classmethod
-  def from_state(cls, state, alg):
-    """Creates a generator from a state.
-
-    See `__init__` for description of `state` and `alg`.
-
-    Args:
-      state: the new state.
-      alg: the RNG algorithm.
-
-    Returns:
-      The new generator.
-
-    Throws:
-      ValueError: if the generator is created inside a synchronous
-        `tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
-        because there is ambiguity on how to replicate a generator (e.g. should
-        it be copied so such each replica will get the same random numbers, or
-        should it be "split" into different generators that generate
-        different random numbers).
-    """
-    return cls(alg=alg, state=state)
-
-  @classmethod
-  def from_seed(cls, seed, alg=None):
-    """Creates a generator from a seed.
-
-    A seed is a 1024-bit unsigned integer represented either as a Python
-    integer or a vector of integers. Seeds shorter than 1024-bit will be
-    padded. The padding, the internal structure of a seed and the way a seed
-    is converted to a state are all opaque (unspecified). The only semantics
-    specification of seeds is that two different seeds are likely to produce
-    two independent generators (but no guarantee).
-
-    Args:
-      seed: the seed for the RNG.
-      alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
-        `__init__` for its possible values.
-
-    Returns:
-      The new generator.
-
-    Throws:
-      ValueError: if the generator is created inside a synchronous
-        `tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
-        because there is ambiguity on how to replicate a generator (e.g. should
-        it be copied so such each replica will get the same random numbers, or
-        should it be "split" into different generators that generate
-        different random numbers).
-    """
-    if alg is None:
-      # TODO(wangpeng): more sophisticated algorithm selection
-      alg = DEFAULT_ALGORITHM
-    alg = _convert_alg_to_int(alg)
-    state = create_rng_state(seed, alg)
-    return cls(state=state, alg=alg)
-
-  @classmethod
-  def from_non_deterministic_state(cls, alg=None):
-    """Creates a generator by non-deterministically initializing its state.
-
-    The source of the non-determinism will be platform- and time-dependent.
-
-    Args:
-      alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
-        `__init__` for its possible values.
-
-    Returns:
-      The new generator.
-
-    Throws:
-      ValueError: if the generator is created inside a synchronous
-        `tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
-        because there is ambiguity on how to replicate a generator (e.g. should
-        it be copied so such each replica will get the same random numbers, or
-        should it be "split" into different generators that generate
-        different random numbers).
-    """
-    if alg is None:
-      # TODO(wangpeng): more sophisticated algorithm selection
-      alg = DEFAULT_ALGORITHM
-    alg = _convert_alg_to_int(alg)
-    state = non_deterministic_ints(shape=[_get_state_size(alg)],
-                                   dtype=SEED_TYPE)
-    return cls(state=state, alg=alg)
-
-  @classmethod
-  def from_key_counter(cls, key, counter, alg):
-    """Creates a generator from a key and a counter.
-
-    This constructor only applies if the algorithm is a counter-based
-    algorithm. See method `key` for the meaning of "key" and "counter".
-
-    Args:
-      key: the key for the RNG, a scalar of type STATE_TYPE.
-      counter: a vector of dtype STATE_TYPE representing the initial counter
-        for the RNG, whose length is algorithm-specific.
-      alg: the RNG algorithm. If None, it will be auto-selected. See
-        `__init__` for its possible values.
-
-    Returns:
-      The new generator.
-
-    Throws:
-      ValueError: if the generator is created inside a synchronous
-        `tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
-        because there is ambiguity on how to replicate a generator (e.g. should
-        it be copied so such each replica will get the same random numbers, or
-        should it be "split" into different generators that generate
-        different random numbers).
-    """
-    counter = _convert_to_state_tensor(counter)
-    key = _convert_to_state_tensor(key)
-    alg = _convert_alg_to_int(alg)
-    counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
-    key.shape.assert_is_compatible_with([])
-    key = array_ops.reshape(key, [1])
-    state = array_ops.concat([counter, key], 0)
-    return cls(state=state, alg=alg)
+    return variables.Variable(*args, **kwargs)

   def reset(self, state):
     """Resets the generator by a new state.
@@ -556,11 +531,6 @@
     state = array_ops.concat([counter, key], 0)
     self._state_var.assign(state)

-  @property
-  def _type_spec(self):
-    return GeneratorSpec(shape=self.state.shape, dtype=self.state.dtype,
-                         alg=self.algorithm)
-
   @property
   def state(self):
     """The internal state of the RNG."""
@@ -614,15 +584,52 @@
       counter is an unspecified implementation detail.
     """
     if compat.forward_compatible(2020, 10, 25):
-      return gen_stateful_random_ops.rng_read_and_skip(
-          self.state.handle,
-          alg=math_ops.cast(self.algorithm, dtypes.int32),
-          delta=math_ops.cast(delta, dtypes.uint64))
+      return self._skip(delta)
     gen_stateful_random_ops.rng_skip(
         self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
         math_ops.cast(delta, dtypes.int64))
   # pylint: enable=g-doc-return-or-yield

+  def _skip_single_var(self, var, delta):
+    # TODO(wangpeng): Cache the cast algorithm instead of casting every time.
+    return gen_stateful_random_ops.rng_read_and_skip(
+        var.handle, alg=math_ops.cast(self.algorithm, dtypes.int32),
+        delta=math_ops.cast(delta, dtypes.uint64))
+
+  def _skip(self, delta):
+    def update_fn(v):
+      return self._skip_single_var(v, delta)
+    # TODO(b/170515001): Always call strategy.extended.update once calling it
+    # from both replica context and cross-replica context is supported.
+    if values_util.is_saving_non_distributed():
+      # Assumes replica context with replica_id=0, since we only save the first
+      # replica.
+      return update_fn(self.state)
+    if self._distribution_strategy is not None:
+      with ds_context.enter_or_assert_strategy(self._distribution_strategy):
+        if ds_context.in_cross_replica_context():
+          # Code that operates on all replicas of a variable cannot be saved
+          # without retracing.
+          values_util.mark_as_unsaveable()
+          # In cross-replica context we need to use strategy.extended.update.
+          return ds_context.get_strategy().extended.update(
+              self.state, update_fn)
+    return update_fn(self.state)
+
+  def _preprocess_key(self, key):
+    if self._distribution_strategy is None:
+      return key
+    with ds_context.enter_or_assert_strategy(self._distribution_strategy):
+      replica_id = get_replica_id()
+      if replica_id is not None:
+        replica_id = array_ops.stack([replica_id, 0], axis=0)
+        replica_id = math_ops.cast(replica_id, dtypes.uint64)
+        # Conceptually: key = hash(key, replica_id)
+        key = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
+            shape=[1], key=key, counter=replica_id, dtype=dtypes.uint64,
+            alg=self.algorithm)
+      return key
+
   def _prepare_key_counter(self, shape):
     delta = math_ops.reduce_prod(shape)
     counter_key = self.skip(delta)

@@ -630,6 +637,7 @@
     counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
     key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
                             dtypes.uint64)
+    key = self._preprocess_key(key)
     return key, counter

   # The following functions return a tensor and as a side effect update


@@ -18,11 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 import re

 from absl.testing import parameterized
 import numpy as np
-import six

 from tensorflow.python.distribute import values as dist_values
 from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
@@ -33,7 +33,6 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.kernel_tests.random import util as \
     random_test_util
@@ -45,6 +44,7 @@ from tensorflow.python.ops import stateful_random_ops as \
     random
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.training.tracking import util as tracking_util

 g_seeded = None
@@ -636,49 +636,6 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(res1, res2)
     self.assertAllEqual(g1.state.read_value(), g2.state.read_value())

-  @test_util.run_v2_only
-  def testFunArgAlgIsInt(self):
-    """Tests that `algorithm` is `int` when reconstructed from composite tensor.
-    """
-    @def_function.function
-    def f(g):
-      self.assertIsInstance(g.algorithm, six.integer_types)
-      return g.make_seeds(), g
-    gen = random.Generator.from_seed(123, alg="philox")
-    f(gen)
-
-  @test_util.run_v2_only
-  def testLimitedRetracingWithCompositeTensors(self):
-    """Tests that RNGs with the same shape/dtype won't cause retracing.
-    """
-    trace_count = [0]
-
-    @def_function.function
-    def f(x):
-      trace_count[0] += 1
-      return x.normal([])
-
-    f(random.Generator.from_seed(1))
-    f(random.Generator.from_seed(2))
-    self.assertEqual(trace_count[0], 1)
-
-  def testMostSpecificCompatibleType(self):
-    """Tests GeneratorSpec.most_specific_compatible_type.
-    """
-    spec = random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int32)
-    res = spec.most_specific_compatible_type(
-        random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int32))
-    self.assertEqual(spec, res)
-    with self.assertRaisesWithPredicateMatch(ValueError, ""):
-      spec.most_specific_compatible_type(
-          tensor_spec.TensorSpec(shape=(2, 3), dtype=dtypes.int32))
-    with self.assertRaisesWithPredicateMatch(ValueError, ""):
-      spec.most_specific_compatible_type(
-          random.GeneratorSpec(shape=(2, 4), dtype=dtypes.int32))
-    with self.assertRaisesWithPredicateMatch(ValueError, ""):
-      spec.most_specific_compatible_type(
-          random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int64))
-
   @test_util.run_v2_only
   def testCreateOutsideMirroredStrat(self):
     """Tests RNG/MirrorStrategy interaction #1.
@@ -702,29 +659,6 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(2, len(values))
     self.assertAllDifferent(values)

-  @test_util.run_v2_only
-  def testMirroredStratParaSyncDisallowed(self):
-    """Tests that generator creation in MirroredStrategy is disallowed.
-    """
-    creators = [
-        lambda: random.Generator.from_seed(1234),
-        random.Generator.from_non_deterministic_state,
-    ]
-    shape = [3, 4]
-    dtype = dtypes.int32
-    strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
-    for creator in creators:
-      with strat.scope():
-        with self.assertRaisesWithPredicateMatch(
-            ValueError, "disallowed"):
-          creator()  # pylint: disable=cell-var-from-loop
-
-      def f():
-        gen = creator()  # pylint: disable=cell-var-from-loop
-        return gen.uniform_full_int(shape=shape, dtype=dtype)
-
-      with self.assertRaisesWithPredicateMatch(
-          ValueError, "disallowed"):
-        strat.extended.call_for_each_replica(fn=f)
-
   @test_util.run_v2_only
   def testMirroredStratParaAsync(self):
     """Tests RNG/MirrorStrategy interaction #2.
@@ -764,6 +698,23 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
       r2 = g.uniform_full_int(shape=shape, dtype=dtype)
     self.assertAllEqual(r1, r2)

+  @test_util.run_v2_only
+  def testRestore(self):
+    """Tests save and restore.
+    """
+    fname = os.path.join(self.get_temp_dir(), "checkpoint")
+    g = random.Generator.from_seed(1)
+    cp = tracking_util.Checkpoint(g=g)
+
+    def write_restore_compare():
+      cp.write(fname)
+      r1 = g.uniform([], dtype=dtypes.uint32, minval=None)
+      cp.restore(fname)
+      r2 = g.uniform([], dtype=dtypes.uint32, minval=None)
+      self.assertAllEqual(r1, r2)
+
+    # Run multiple times so that cp.write is called in various RNG states.
+    for _ in range(2):
+      write_restore_compare()

 if __name__ == "__main__":
   config.set_soft_device_placement(False)


@@ -3,7 +3,6 @@ tf_class {
   is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
-  is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
   is_instance: "<type \'object\'>"
   member {
     name: "algorithm"


@@ -3,7 +3,6 @@ tf_class {
   is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
-  is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
   is_instance: "<type \'object\'>"
   member {
     name: "algorithm"


@@ -3,7 +3,6 @@ tf_class {
   is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
-  is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
   is_instance: "<type \'object\'>"
   member {
     name: "algorithm"


@@ -3,7 +3,6 @@ tf_class {
   is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
-  is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
   is_instance: "<type \'object\'>"
   member {
     name: "algorithm"