Allows creating tf.random.Generator under distribution-strategy scopes. Different replicas will get different random-number streams.
All strategies are supported except for CentralStorageStrategy and ParameterServerStrategy. This CL also removes the CompositeTensor superclass from Generator. Generator is a wrapper around tf.Variable, and because tf.Variable is not a CompositeTensor, Generator can't be a CompositeTensor in theory. Previously we made it a CompositeTensor by returning Variable.handle, but that breaks down when the variable is a DistributedVariable (in cross-replica context). PiperOrigin-RevId: 350851648 Change-Id: I5f4d77ddb990557fcc9c7336987203ecdaec5b9a
This commit is contained in:
parent
70502be4a5
commit
587ac71f68
@ -24,6 +24,7 @@
|
||||
`_tpu_estimator_embedding.py`. This allows embedding lookup statistics
|
||||
gathered at runtime to be used in embedding layer partitioning decisions.
|
||||
* `tf.keras.metrics.AUC` now support logit predictions.
|
||||
* Creating `tf.random.Generator` under `tf.distribute.Strategy` scopes is now allowed (except for `tf.distribute.experimental.CentralStorageStrategy` and `tf.distribute.experimental.ParameterServerStrategy`). Different replicas will get different random-number streams.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
|
@ -587,6 +587,22 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
distribute_py_test(
|
||||
name = "random_generator_test",
|
||||
srcs = ["random_generator_test.py"],
|
||||
main = "random_generator_test.py",
|
||||
shard_count = 10,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:stateful_random_ops",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
],
|
||||
)
|
||||
|
||||
tpu_py_test(
|
||||
name = "tpu_strategy_test",
|
||||
srcs = ["tpu_strategy_test.py"],
|
||||
|
255
tensorflow/python/distribute/random_generator_test.py
Normal file
255
tensorflow/python/distribute/random_generator_test.py
Normal file
@ -0,0 +1,255 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests tf.random.Generator with distribution strategies."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.compat import v2_compat
|
||||
from tensorflow.python.distribute import combinations as ds_combinations
|
||||
from tensorflow.python.distribute import multi_process_runner
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_combinations as combinations
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import stateful_random_ops as rng
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.training.tracking import util as tracking_util
|
||||
|
||||
|
||||
def get_num_local_replicas(strat, values=None):
|
||||
strat_name = type(strat).__name__
|
||||
if "MultiWorker" in strat_name or "CollectiveAllReduceStrategy" in strat_name:
|
||||
if values is None:
|
||||
values = strat.run(lambda: constant_op.constant(0))
|
||||
values = strat.experimental_local_results(values)
|
||||
return len(values)
|
||||
else:
|
||||
return strat.num_replicas_in_sync
|
||||
|
||||
|
||||
all_strategies = (strategy_combinations.all_strategies +
|
||||
strategy_combinations.multiworker_strategies)
|
||||
|
||||
|
||||
class GeneratorTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(GeneratorTest, self).setUp()
|
||||
v2_compat.enable_v2_behavior()
|
||||
config.set_soft_device_placement(False)
|
||||
|
||||
def assertAllDifferent(self, tensors):
|
||||
"""Checks that there are no duplicate elements anywhere among the tensors.
|
||||
|
||||
Args:
|
||||
tensors: a list of tensors. They can have different shapes.
|
||||
"""
|
||||
values = [array_ops.reshape(t, shape=[-1]) for t in tensors]
|
||||
values = array_ops.concat(values, axis=0)
|
||||
values = self.evaluate(values)
|
||||
values = values.tolist()
|
||||
self.assertAllEqual(len(values), len(set(values)))
|
||||
|
||||
@ds_combinations.generate(
|
||||
combinations.combine(
|
||||
strat=all_strategies,
|
||||
mode=["eager"]))
|
||||
def testCrossReplica(self, strat):
|
||||
"""Tests that RNG can be properly advanced in cross-replica context."""
|
||||
strat_name = type(strat).__name__
|
||||
if "CentralStorage" in strat_name:
|
||||
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||
def read_values(dv):
|
||||
return [v.read_value() for v in strat.experimental_local_results(dv)]
|
||||
with strat.scope():
|
||||
g = rng.Generator.from_seed(1)
|
||||
s1 = read_values(g.state)
|
||||
g.normal([3])
|
||||
g.skip(4)
|
||||
s2 = read_values(g.state)
|
||||
self.assertNotAllEqual(s1[0], s2[0])
|
||||
self.assertEqual(len(s1), len(s2))
|
||||
for i in range(1, len(s1)):
|
||||
self.assertAllEqual(s1[0], s1[i])
|
||||
self.assertAllEqual(s2[0], s2[i])
|
||||
|
||||
@ds_combinations.generate(
|
||||
combinations.combine(
|
||||
strat=all_strategies,
|
||||
mode=["eager"],
|
||||
seeded=[True, False]))
|
||||
def testDistStrat(self, strat, seeded):
|
||||
"""Tests RNG with distribution strategies."""
|
||||
strat_name = type(strat).__name__
|
||||
if "CentralStorage" in strat_name:
|
||||
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||
creators = {
|
||||
True: functools.partial(rng.Generator.from_seed, 1234),
|
||||
False: rng.Generator.from_non_deterministic_state,
|
||||
}
|
||||
shape = [3, 4]
|
||||
dtype = dtypes.int32
|
||||
creator = creators[seeded]
|
||||
with strat.scope():
|
||||
gen = creator()
|
||||
@def_function.function
|
||||
def f():
|
||||
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
t = array_ops.stack([t1, t2])
|
||||
return t
|
||||
results = strat.run(f)
|
||||
values = strat.experimental_local_results(results)
|
||||
n = get_num_local_replicas(strat, values)
|
||||
self.assertAllEqual(n, len(values))
|
||||
self.assertAllDifferent(values)
|
||||
|
||||
@ds_combinations.generate(
|
||||
combinations.combine(
|
||||
strat=all_strategies,
|
||||
mode=["eager"]))
|
||||
def testDistVarAsTFFunArg(self, strat):
|
||||
"""Tests that RNG with dist variables can be used as tf.function's arg."""
|
||||
strat_name = type(strat).__name__
|
||||
if "CentralStorage" in strat_name:
|
||||
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||
shape = [3, 4]
|
||||
dtype = dtypes.int32
|
||||
with strat.scope():
|
||||
gen = rng.Generator.from_seed(1234)
|
||||
@def_function.function
|
||||
def f(gen): # the main focus
|
||||
t1 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
t2 = gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
t = array_ops.stack([t1, t2])
|
||||
return t
|
||||
@def_function.function # required by TPUStrategy.run
|
||||
def g():
|
||||
return f(gen)
|
||||
for _ in range(2):
|
||||
results = strat.run(g)
|
||||
values = strat.experimental_local_results(results)
|
||||
n = get_num_local_replicas(strat, values)
|
||||
self.assertAllEqual(n, len(values))
|
||||
self.assertAllDifferent(values)
|
||||
|
||||
@ds_combinations.generate(
|
||||
combinations.combine(
|
||||
strat1=strategy_combinations.all_strategies,
|
||||
strat2=strategy_combinations.all_strategies,
|
||||
mode=["eager"]) +
|
||||
combinations.combine(
|
||||
strat1=strategy_combinations.multiworker_strategies,
|
||||
strat2=[None],
|
||||
mode=["eager"]))
|
||||
def testDistStratRestore(self, strat1, strat2):
|
||||
"""Tests checkpointing and restoring (to possibly different #replicas)."""
|
||||
if strat2 is None:
|
||||
strat2 = strat1
|
||||
strat1_name = type(strat1).__name__
|
||||
strat2_name = type(strat2).__name__
|
||||
if "CentralStorage" in strat1_name or "CentralStorage" in strat2_name:
|
||||
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||
if "Default" in strat1_name or "Default" in strat2_name:
|
||||
self.skipTest(
|
||||
"We don't guarantee consistency between strategy and no-strategy.")
|
||||
fname = os.path.join(self.get_temp_dir(), "checkpoint")
|
||||
def uniform(strat, g):
|
||||
@def_function.function
|
||||
def f():
|
||||
return g.uniform_full_int([3], dtype=dtypes.int32)
|
||||
result = strat.run(f)
|
||||
return strat.experimental_local_results(result)
|
||||
with strat1.scope():
|
||||
g1 = rng.Generator.from_seed(1)
|
||||
with strat2.scope():
|
||||
g2 = rng.Generator.from_seed(10)
|
||||
cp1 = tracking_util.Checkpoint(g=g1)
|
||||
cp2 = tracking_util.Checkpoint(g=g2)
|
||||
def write_restore_compare():
|
||||
cp1.write(fname)
|
||||
r1 = uniform(strat1, g1)
|
||||
cp2.restore(fname)
|
||||
r2 = uniform(strat2, g2)
|
||||
# Tests that overlapping replicas are properly restored.
|
||||
n1 = get_num_local_replicas(strat1)
|
||||
n2 = get_num_local_replicas(strat2)
|
||||
n = min(n1, n2)
|
||||
self.assertAllEqual(r1[:n], r2[:n])
|
||||
# Run multiple times so that cp1.write is called in various RNG states
|
||||
for _ in range(2):
|
||||
write_restore_compare()
|
||||
|
||||
@ds_combinations.generate(
|
||||
combinations.combine(
|
||||
strat=strategy_combinations.all_strategies,
|
||||
mode=["eager"],
|
||||
is_save_in_scope=[True, False]))
|
||||
def testSavedModel(self, strat, is_save_in_scope):
|
||||
strat_name = type(strat).__name__
|
||||
if "CentralStorage" in strat_name:
|
||||
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||
|
||||
class CustomModule(module.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(CustomModule, self).__init__()
|
||||
self.g = rng.Generator.from_seed(0)
|
||||
|
||||
@def_function.function
|
||||
def __call__(self):
|
||||
return self.g.state
|
||||
|
||||
@def_function.function
|
||||
def mutate(self):
|
||||
self.g.normal([])
|
||||
|
||||
with strat.scope():
|
||||
m = CustomModule()
|
||||
m.mutate()
|
||||
state_before = m()
|
||||
path = os.path.join(self.get_temp_dir(), "saved_model")
|
||||
if is_save_in_scope:
|
||||
with strat.scope():
|
||||
save.save(m, path)
|
||||
else:
|
||||
save.save(m, path)
|
||||
with strat.scope():
|
||||
m.mutate()
|
||||
state_before_2 = m()
|
||||
|
||||
imported = load.load(path)
|
||||
state_after = imported()
|
||||
self.assertAllEqual(state_before, state_after)
|
||||
imported.mutate()
|
||||
state_after_2 = imported()
|
||||
self.assertAllEqual(state_before_2, state_after_2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
multi_process_runner.test_main()
|
@ -39,7 +39,6 @@ from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import stateful_random_ops
|
||||
from tensorflow.python.ops import stateless_random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
ResizeMethod = image_ops.ResizeMethod
|
||||
@ -1356,44 +1355,20 @@ class RandomWidth(PreprocessingLayer):
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
# TODO(b/147877541, b/158339556): This class is added to temporarily enable
|
||||
# creating generators within distribution strategies. Remove it when the proper
|
||||
# API is in place.
|
||||
class _RandomGenerator(stateful_random_ops.Generator):
|
||||
"""A subclass that allows creation inside distribution strategies.
|
||||
|
||||
This is a temporary solution to allow creating tf.random.Generator inside
|
||||
distribution strategies. It will be removed when proper API is in place.
|
||||
|
||||
All replicas will have the same RNG state and generate the same random
|
||||
numbers.
|
||||
"""
|
||||
|
||||
# TODO(b/157995497): Temporarily use primary variable handle inside cross
|
||||
# replica context.
|
||||
@property
|
||||
def state(self):
|
||||
"""The internal state of the RNG."""
|
||||
state_var = self._state_var
|
||||
try:
|
||||
_ = getattr(state_var, 'handle')
|
||||
return state_var
|
||||
except ValueError:
|
||||
return state_var.values[0]
|
||||
|
||||
def _create_variable(self, *args, **kwargs):
|
||||
# This function does the same thing as the base class's namesake, except
|
||||
# that it skips the distribution-strategy check. When we are inside a
|
||||
# distribution-strategy scope, variables.Variable will pick a proper
|
||||
# variable class (e.g. MirroredVariable).
|
||||
return variables.Variable(*args, **kwargs)
|
||||
|
||||
|
||||
def make_generator(seed=None):
|
||||
"""Creates a random generator.
|
||||
|
||||
Args:
|
||||
seed: the seed to initialize the generator. If None, the generator will be
|
||||
initialized non-deterministically.
|
||||
|
||||
Returns:
|
||||
A generator object.
|
||||
"""
|
||||
if seed:
|
||||
return _RandomGenerator.from_seed(seed)
|
||||
return stateful_random_ops.Generator.from_seed(seed)
|
||||
else:
|
||||
return _RandomGenerator.from_non_deterministic_state()
|
||||
return stateful_random_ops.Generator.from_non_deterministic_state()
|
||||
|
||||
|
||||
def get_interpolation(interpolation):
|
||||
|
@ -41,6 +41,8 @@ class ImagePreprocessingDistributionTest(
|
||||
preprocessing_test_utils.PreprocessingLayerTest):
|
||||
|
||||
def test_distribution(self, distribution):
|
||||
if "CentralStorage" in type(distribution).__name__:
|
||||
self.skipTest("Does not work with CentralStorageStrategy yet.")
|
||||
# TODO(b/159738418): large image input causes OOM in ubuntu multi gpu.
|
||||
np_images = np.random.random((32, 32, 32, 3)).astype(np.float32)
|
||||
image_dataset = dataset_ops.Dataset.from_tensor_slices(np_images).batch(
|
||||
|
@ -1082,8 +1082,6 @@ class RandomRotationTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_distribution_strategy(self):
|
||||
"""Tests that RandomRotation can be created within distribution strategies.
|
||||
|
||||
And that replicas got the same random result.
|
||||
"""
|
||||
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
|
||||
with testing_utils.use_gpu():
|
||||
@ -1093,7 +1091,6 @@ class RandomRotationTest(keras_parameterized.TestCase):
|
||||
output = strat.run(lambda: layer(input_images, training=True))
|
||||
values = output.values
|
||||
self.assertAllEqual(2, len(values))
|
||||
self.assertAllClose(values[0], values[1], rtol=1e-5)
|
||||
|
||||
@testing_utils.run_v2_only
|
||||
def test_config_with_custom_name(self):
|
||||
|
@ -25,17 +25,14 @@ import six
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import values_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_stateful_random_ops
|
||||
from tensorflow.python.ops import gen_stateless_random_ops_v2
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -235,40 +232,15 @@ def _convert_to_state_tensor(t):
|
||||
return ops.convert_to_tensor(t, dtype=STATE_TYPE)
|
||||
|
||||
|
||||
class GeneratorSpec(type_spec.TypeSpec):
|
||||
"""TypeSpec for Generator."""
|
||||
|
||||
def __init__(self, shape=None, dtype=None, alg=None):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.alg = alg
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
return (tensor_spec.TensorSpec(shape=(), dtype=dtypes.resource),)
|
||||
|
||||
def _to_components(self, value):
|
||||
return (value.state.handle,)
|
||||
|
||||
def _from_components(self, components):
|
||||
assert isinstance(components, (list, tuple))
|
||||
assert len(components) == 1
|
||||
handle = components[0]
|
||||
state_var = resource_variable_ops.BaseResourceVariable(
|
||||
handle=handle, shape=self.shape, dtype=self.dtype,
|
||||
trainable=False, handle_deleter=object(), handle_name="RNGVar")
|
||||
return Generator(state=state_var, alg=self.alg)
|
||||
|
||||
@property
|
||||
def value_type(self):
|
||||
return Generator
|
||||
|
||||
def _serialize(self):
|
||||
return (self.shape, self.dtype, self.alg)
|
||||
def get_replica_id():
|
||||
rctx = ds_context.get_replica_context()
|
||||
if rctx is None:
|
||||
return None
|
||||
return rctx.replica_id_in_sync_group
|
||||
|
||||
|
||||
@tf_export("random.Generator", "random.experimental.Generator")
|
||||
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
class Generator(tracking.AutoTrackable):
|
||||
"""Random-number generator.
|
||||
|
||||
Example:
|
||||
@ -320,8 +292,137 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
>>> g = tf.random.get_global_generator()
|
||||
>>> g.normal(shape=(2, 3))
|
||||
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
|
||||
|
||||
When creating a generator inside a `tf.distribute.Strategy` scope, each
|
||||
replica will get a different stream of random numbers.
|
||||
|
||||
Note: `tf.distribute.experimental.CentralStorageStrategy` and
|
||||
`tf.distribute.experimental.ParameterServerStrategy` are not supported yet.
|
||||
|
||||
For example, in this code:
|
||||
|
||||
```
|
||||
strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
|
||||
with strat.scope():
|
||||
g = tf.random.Generator.from_seed(1)
|
||||
def f():
|
||||
return g.normal([])
|
||||
results = strat.run(f).values
|
||||
```
|
||||
|
||||
`results[0]` and `results[1]` will have different values.
|
||||
|
||||
If the generator is seeded (e.g. created via `Generator.from_seed`), the
|
||||
random numbers will be determined by the seed, even though different replicas
|
||||
get different numbers. One can think of a random number generated on a
|
||||
replica as a hash of the replica ID and a "master" random number that may be
|
||||
common to all replicas. Hence, the whole system is still deterministic.
|
||||
|
||||
(Note that the random numbers on different replicas are not correlated, even
|
||||
if they are deterministically determined by the same seed. They are not
|
||||
correlated in the sense that no matter what statistics one calculates on them,
|
||||
there won't be any discernable correlation.)
|
||||
|
||||
Generators can be freely saved and restored using `tf.train.Checkpoint`. The
|
||||
checkpoint can be restored in a distribution strategy with a different number
|
||||
of replicas than the original strategy. If a replica ID is present in both the
|
||||
original and the new distribution strategy, its state will be properly
|
||||
restored (i.e. the random-number stream from the restored point will be the
|
||||
same as that from the saving point) unless the replicas have already diverged
|
||||
in their RNG call traces before saving (e.g. one replica has made one RNG call
|
||||
while another has made two RNG calls). We don't have such guarantee if the
|
||||
generator is saved in a strategy scope and restored outside of any strategy
|
||||
scope, or vice versa.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state, alg):
|
||||
"""Creates a generator from a state.
|
||||
|
||||
See `__init__` for description of `state` and `alg`.
|
||||
|
||||
Args:
|
||||
state: the new state.
|
||||
alg: the RNG algorithm.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
"""
|
||||
return cls(alg=alg, state=state)
|
||||
|
||||
@classmethod
|
||||
def from_seed(cls, seed, alg=None):
|
||||
"""Creates a generator from a seed.
|
||||
|
||||
A seed is a 1024-bit unsigned integer represented either as a Python
|
||||
integer or a vector of integers. Seeds shorter than 1024-bit will be
|
||||
padded. The padding, the internal structure of a seed and the way a seed
|
||||
is converted to a state are all opaque (unspecified). The only semantics
|
||||
specification of seeds is that two different seeds are likely to produce
|
||||
two independent generators (but no guarantee).
|
||||
|
||||
Args:
|
||||
seed: the seed for the RNG.
|
||||
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
|
||||
`__init__` for its possible values.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
"""
|
||||
if alg is None:
|
||||
# TODO(b/170668986): more sophisticated algorithm selection
|
||||
alg = DEFAULT_ALGORITHM
|
||||
alg = _convert_alg_to_int(alg)
|
||||
state = create_rng_state(seed, alg)
|
||||
return cls(state=state, alg=alg)
|
||||
|
||||
@classmethod
|
||||
def from_non_deterministic_state(cls, alg=None):
|
||||
"""Creates a generator by non-deterministically initializing its state.
|
||||
|
||||
The source of the non-determinism will be platform- and time-dependent.
|
||||
|
||||
Args:
|
||||
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
|
||||
`__init__` for its possible values.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
"""
|
||||
if alg is None:
|
||||
# TODO(b/170668986): more sophisticated algorithm selection
|
||||
alg = DEFAULT_ALGORITHM
|
||||
alg = _convert_alg_to_int(alg)
|
||||
state = non_deterministic_ints(shape=[_get_state_size(alg)],
|
||||
dtype=SEED_TYPE)
|
||||
return cls(state=state, alg=alg)
|
||||
|
||||
@classmethod
|
||||
def from_key_counter(cls, key, counter, alg):
|
||||
"""Creates a generator from a key and a counter.
|
||||
|
||||
This constructor only applies if the algorithm is a counter-based algorithm.
|
||||
See method `key` for the meaning of "key" and "counter".
|
||||
|
||||
Args:
|
||||
key: the key for the RNG, a scalar of type STATE_TYPE.
|
||||
counter: a vector of dtype STATE_TYPE representing the initial counter for
|
||||
the RNG, whose length is algorithm-specific.,
|
||||
alg: the RNG algorithm. If None, it will be auto-selected. See
|
||||
`__init__` for its possible values.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
"""
|
||||
counter = _convert_to_state_tensor(counter)
|
||||
key = _convert_to_state_tensor(key)
|
||||
alg = _convert_alg_to_int(alg)
|
||||
counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
|
||||
key.shape.assert_is_compatible_with([])
|
||||
key = array_ops.reshape(key, [1])
|
||||
state = array_ops.concat([counter, key], 0)
|
||||
return cls(state=state, alg=alg)
|
||||
|
||||
def __init__(self, copy_from=None, state=None, alg=None):
|
||||
"""Creates a generator.
|
||||
|
||||
@ -346,24 +447,26 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
The string names `"philox"` and `"threefry"` can also be used.
|
||||
Note `PHILOX` guarantees the same numbers are produced (given
|
||||
the same random state) across all architectures (CPU, GPU, XLA etc).
|
||||
|
||||
Throws:
|
||||
ValueError: if the generator is created inside a synchronous
|
||||
`tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
|
||||
because there is ambiguity on how to replicate a generator (e.g. should
|
||||
it be copied so such each replica will get the same random numbers, or
|
||||
should it be "split" into different generators that generate
|
||||
different random numbers).
|
||||
"""
|
||||
# TODO(b/175072242): Remove distribution-strategy dependencies in this file.
|
||||
if ds_context.has_strategy():
|
||||
self._distribution_strategy = ds_context.get_strategy()
|
||||
else:
|
||||
self._distribution_strategy = None
|
||||
if copy_from is not None:
|
||||
# All other arguments should be None
|
||||
assert (alg or state) is None
|
||||
self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
|
||||
trainable=False)
|
||||
self._alg = copy_from.algorithm
|
||||
|
||||
else:
|
||||
assert alg is not None and state is not None
|
||||
if ds_context.has_strategy():
|
||||
strat_name = type(ds_context.get_strategy()).__name__
|
||||
# TODO(b/174610856): Support CentralStorageStrategy and
|
||||
# ParameterServerStrategy.
|
||||
if "CentralStorage" in strat_name or "ParameterServer" in strat_name:
|
||||
raise ValueError("%s is not supported yet" % strat_name)
|
||||
alg = _convert_alg_to_int(alg)
|
||||
if isinstance(state, variables.Variable):
|
||||
_check_state_shape(state.shape, alg)
|
||||
@ -376,7 +479,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
self._alg = alg
|
||||
|
||||
def _create_variable(self, *args, **kwargs):
|
||||
"""Creates a variable, and check that it's not MirroredVariable.
|
||||
"""Creates a variable.
|
||||
|
||||
Args:
|
||||
*args: positional arguments passed along to `variables.Variable.
|
||||
@ -385,135 +488,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
Returns:
|
||||
The created variable.
|
||||
"""
|
||||
if ds_context.has_strategy():
|
||||
raise ValueError(
|
||||
"Creating a generator within a strategy scope is disallowed, because "
|
||||
"there is ambiguity on how to replicate a generator (e.g. should it "
|
||||
"be copied so that each replica gets the same random numbers, or "
|
||||
"'split' so that each replica gets different random numbers).")
|
||||
# TODO(wangpeng): Link to the RNG guide for solutions in such cases.
|
||||
var = variables.Variable(*args, **kwargs)
|
||||
return var
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state, alg):
|
||||
"""Creates a generator from a state.
|
||||
|
||||
See `__init__` for description of `state` and `alg`.
|
||||
|
||||
Args:
|
||||
state: the new state.
|
||||
alg: the RNG algorithm.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
|
||||
Throws:
|
||||
ValueError: if the generator is created inside a synchronous
|
||||
`tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
|
||||
because there is ambiguity on how to replicate a generator (e.g. should
|
||||
it be copied so such each replica will get the same random numbers, or
|
||||
should it be "split" into different generators that generate
|
||||
different random numbers).
|
||||
"""
|
||||
return cls(alg=alg, state=state)
|
||||
|
||||
@classmethod
|
||||
def from_seed(cls, seed, alg=None):
|
||||
"""Creates a generator from a seed.
|
||||
|
||||
A seed is a 1024-bit unsigned integer represented either as a Python
|
||||
integer or a vector of integers. Seeds shorter than 1024-bit will be
|
||||
padded. The padding, the internal structure of a seed and the way a seed
|
||||
is converted to a state are all opaque (unspecified). The only semantics
|
||||
specification of seeds is that two different seeds are likely to produce
|
||||
two independent generators (but no guarantee).
|
||||
|
||||
Args:
|
||||
seed: the seed for the RNG.
|
||||
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
|
||||
`__init__` for its possible values.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
|
||||
Throws:
|
||||
ValueError: if the generator is created inside a synchronous
|
||||
`tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
|
||||
because there is ambiguity on how to replicate a generator (e.g. should
|
||||
it be copied so such each replica will get the same random numbers, or
|
||||
should it be "split" into different generators that generate
|
||||
different random numbers).
|
||||
"""
|
||||
if alg is None:
|
||||
# TODO(wangpeng): more sophisticated algorithm selection
|
||||
alg = DEFAULT_ALGORITHM
|
||||
alg = _convert_alg_to_int(alg)
|
||||
state = create_rng_state(seed, alg)
|
||||
return cls(state=state, alg=alg)
|
||||
|
||||
@classmethod
|
||||
def from_non_deterministic_state(cls, alg=None):
|
||||
"""Creates a generator by non-deterministically initializing its state.
|
||||
|
||||
The source of the non-determinism will be platform- and time-dependent.
|
||||
|
||||
Args:
|
||||
alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
|
||||
`__init__` for its possible values.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
|
||||
Throws:
|
||||
ValueError: if the generator is created inside a synchronous
|
||||
`tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
|
||||
because there is ambiguity on how to replicate a generator (e.g. should
|
||||
it be copied so such each replica will get the same random numbers, or
|
||||
should it be "split" into different generators that generate
|
||||
different random numbers).
|
||||
"""
|
||||
if alg is None:
|
||||
# TODO(wangpeng): more sophisticated algorithm selection
|
||||
alg = DEFAULT_ALGORITHM
|
||||
alg = _convert_alg_to_int(alg)
|
||||
state = non_deterministic_ints(shape=[_get_state_size(alg)],
|
||||
dtype=SEED_TYPE)
|
||||
return cls(state=state, alg=alg)
|
||||
|
||||
@classmethod
|
||||
def from_key_counter(cls, key, counter, alg):
|
||||
"""Creates a generator from a key and a counter.
|
||||
|
||||
This constructor only applies if the algorithm is a counter-based algorithm.
|
||||
See method `key` for the meaning of "key" and "counter".
|
||||
|
||||
Args:
|
||||
key: the key for the RNG, a scalar of type STATE_TYPE.
|
||||
counter: a vector of dtype STATE_TYPE representing the initial counter for
|
||||
the RNG, whose length is algorithm-specific.,
|
||||
alg: the RNG algorithm. If None, it will be auto-selected. See
|
||||
`__init__` for its possible values.
|
||||
|
||||
Returns:
|
||||
The new generator.
|
||||
|
||||
Throws:
|
||||
ValueError: if the generator is created inside a synchronous
|
||||
`tf.distribute` strategy such as `MirroredStrategy` or `TPUStrategy`,
|
||||
because there is ambiguity on how to replicate a generator (e.g. should
|
||||
it be copied so such each replica will get the same random numbers, or
|
||||
should it be "split" into different generators that generate
|
||||
different random numbers).
|
||||
"""
|
||||
counter = _convert_to_state_tensor(counter)
|
||||
key = _convert_to_state_tensor(key)
|
||||
alg = _convert_alg_to_int(alg)
|
||||
counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
|
||||
key.shape.assert_is_compatible_with([])
|
||||
key = array_ops.reshape(key, [1])
|
||||
state = array_ops.concat([counter, key], 0)
|
||||
return cls(state=state, alg=alg)
|
||||
return variables.Variable(*args, **kwargs)
|
||||
|
||||
def reset(self, state):
|
||||
"""Resets the generator by a new state.
|
||||
@ -556,11 +531,6 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
state = array_ops.concat([counter, key], 0)
|
||||
self._state_var.assign(state)
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
return GeneratorSpec(shape=self.state.shape, dtype=self.state.dtype,
|
||||
alg=self.algorithm)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""The internal state of the RNG."""
|
||||
@ -614,15 +584,52 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
counter is an unspecified implementation detail.
|
||||
"""
|
||||
if compat.forward_compatible(2020, 10, 25):
|
||||
return gen_stateful_random_ops.rng_read_and_skip(
|
||||
self.state.handle,
|
||||
alg=math_ops.cast(self.algorithm, dtypes.int32),
|
||||
delta=math_ops.cast(delta, dtypes.uint64))
|
||||
return self._skip(delta)
|
||||
gen_stateful_random_ops.rng_skip(
|
||||
self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
|
||||
math_ops.cast(delta, dtypes.int64))
|
||||
# pylint: enable=g-doc-return-or-yield
|
||||
|
||||
def _skip_single_var(self, var, delta):
|
||||
# TODO(wangpeng): Cache the cast algorithm instead of casting everytime.
|
||||
return gen_stateful_random_ops.rng_read_and_skip(
|
||||
var.handle, alg=math_ops.cast(self.algorithm, dtypes.int32),
|
||||
delta=math_ops.cast(delta, dtypes.uint64))
|
||||
|
||||
def _skip(self, delta):
|
||||
def update_fn(v):
|
||||
return self._skip_single_var(v, delta)
|
||||
# TODO(b/170515001): Always call strategy.extended.update after calling it
|
||||
# from both replica context and cross-replica context is supported.
|
||||
if values_util.is_saving_non_distributed():
|
||||
# Assumes replica context with replica_id=0, since we only save the first
|
||||
# replica.
|
||||
return update_fn(self.state)
|
||||
if self._distribution_strategy is not None:
|
||||
with ds_context.enter_or_assert_strategy(self._distribution_strategy):
|
||||
if ds_context.in_cross_replica_context():
|
||||
# Code that operates on all replicas of a variable cannot be saved
|
||||
# without retracing.
|
||||
values_util.mark_as_unsaveable()
|
||||
# In cross-replica context we need to use strategy.extended.update.
|
||||
return ds_context.get_strategy().extended.update(
|
||||
self.state, update_fn)
|
||||
return update_fn(self.state)
|
||||
|
||||
def _preprocess_key(self, key):
|
||||
if self._distribution_strategy is None:
|
||||
return key
|
||||
with ds_context.enter_or_assert_strategy(self._distribution_strategy):
|
||||
replica_id = get_replica_id()
|
||||
if replica_id is not None:
|
||||
replica_id = array_ops.stack([replica_id, 0], axis=0)
|
||||
replica_id = math_ops.cast(replica_id, dtypes.uint64)
|
||||
# Conceptually: key = hash(key, replica_id)
|
||||
key = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
|
||||
shape=[1], key=key, counter=replica_id, dtype=dtypes.uint64,
|
||||
alg=self.algorithm)
|
||||
return key
|
||||
|
||||
def _prepare_key_counter(self, shape):
|
||||
delta = math_ops.reduce_prod(shape)
|
||||
counter_key = self.skip(delta)
|
||||
@ -630,6 +637,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
|
||||
key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
|
||||
dtypes.uint64)
|
||||
key = self._preprocess_key(key)
|
||||
return key, counter
|
||||
|
||||
# The following functions return a tensor and as a side effect update
|
||||
|
@ -18,11 +18,11 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.python.distribute import values as dist_values
|
||||
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
|
||||
@ -33,7 +33,6 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.kernel_tests.random import util as \
|
||||
random_test_util
|
||||
@ -45,6 +44,7 @@ from tensorflow.python.ops import stateful_random_ops as \
|
||||
random
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.tracking import util as tracking_util
|
||||
|
||||
|
||||
g_seeded = None
|
||||
@ -636,49 +636,6 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(res1, res2)
|
||||
self.assertAllEqual(g1.state.read_value(), g2.state.read_value())
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testFunArgAlgIsInt(self):
|
||||
"""Tests that `algorithm` is `int` when reconstructed from composite tensor.
|
||||
"""
|
||||
@def_function.function
|
||||
def f(g):
|
||||
self.assertIsInstance(g.algorithm, six.integer_types)
|
||||
return g.make_seeds(), g
|
||||
gen = random.Generator.from_seed(123, alg="philox")
|
||||
f(gen)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testLimitedRetracingWithCompositeTensors(self):
|
||||
"""Tests that RNGs with the same shape/dtype won't cause retracing.
|
||||
"""
|
||||
trace_count = [0]
|
||||
|
||||
@def_function.function
|
||||
def f(x):
|
||||
trace_count[0] += 1
|
||||
return x.normal([])
|
||||
|
||||
f(random.Generator.from_seed(1))
|
||||
f(random.Generator.from_seed(2))
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
|
||||
def testMostSpecificCompatibleType(self):
|
||||
"""Tests GeneratorSpec.most_specific_compatible_type.
|
||||
"""
|
||||
spec = random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int32)
|
||||
res = spec.most_specific_compatible_type(
|
||||
random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int32))
|
||||
self.assertEqual(spec, res)
|
||||
with self.assertRaisesWithPredicateMatch(ValueError, ""):
|
||||
spec.most_specific_compatible_type(
|
||||
tensor_spec.TensorSpec(shape=(2, 3), dtype=dtypes.int32))
|
||||
with self.assertRaisesWithPredicateMatch(ValueError, ""):
|
||||
spec.most_specific_compatible_type(
|
||||
random.GeneratorSpec(shape=(2, 4), dtype=dtypes.int32))
|
||||
with self.assertRaisesWithPredicateMatch(ValueError, ""):
|
||||
spec.most_specific_compatible_type(
|
||||
random.GeneratorSpec(shape=(2, 3), dtype=dtypes.int64))
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testCreateOutsideMirroredStrat(self):
|
||||
"""Tests RNG/MirrorStrategy interaction #1.
|
||||
@ -702,29 +659,6 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(2, len(values))
|
||||
self.assertAllDifferent(values)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testMirroredStratParaSyncDisallowed(self):
|
||||
"""Tests that generator creation in MirroredStrategy is disallowed.
|
||||
"""
|
||||
creators = [
|
||||
lambda: random.Generator.from_seed(1234),
|
||||
random.Generator.from_non_deterministic_state,
|
||||
]
|
||||
shape = [3, 4]
|
||||
dtype = dtypes.int32
|
||||
strat = MirroredStrategy(devices=["cpu:0", "cpu:1"])
|
||||
for creator in creators:
|
||||
with strat.scope():
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
ValueError, "disallowed"):
|
||||
creator() # pylint: disable=cell-var-from-loop
|
||||
def f():
|
||||
gen = creator() # pylint: disable=cell-var-from-loop
|
||||
return gen.uniform_full_int(shape=shape, dtype=dtype)
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
ValueError, "disallowed"):
|
||||
strat.extended.call_for_each_replica(fn=f)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testMirroredStratParaAsync(self):
|
||||
"""Tests RNG/MirrorStrategy interaction #2.
|
||||
@ -764,6 +698,23 @@ class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
|
||||
r2 = g.uniform_full_int(shape=shape, dtype=dtype)
|
||||
self.assertAllEqual(r1, r2)
|
||||
|
||||
@test_util.run_v2_only
|
||||
def testRestore(self):
|
||||
"""Tests save and restore.
|
||||
"""
|
||||
fname = os.path.join(self.get_temp_dir(), "checkpoint")
|
||||
g = random.Generator.from_seed(1)
|
||||
cp = tracking_util.Checkpoint(g=g)
|
||||
def write_restore_compare():
|
||||
cp.write(fname)
|
||||
r1 = g.uniform([], dtype=dtypes.uint32, minval=None)
|
||||
cp.restore(fname)
|
||||
r2 = g.uniform([], dtype=dtypes.uint32, minval=None)
|
||||
self.assertAllEqual(r1, r2)
|
||||
# Run multiple times so that cp.write is called in various RNG states
|
||||
for _ in range(2):
|
||||
write_restore_compare()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.set_soft_device_placement(False)
|
||||
|
@ -3,7 +3,6 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "algorithm"
|
||||
|
@ -3,7 +3,6 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "algorithm"
|
||||
|
@ -3,7 +3,6 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "algorithm"
|
||||
|
@ -3,7 +3,6 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.ops.stateful_random_ops.Generator\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "algorithm"
|
||||
|
Loading…
Reference in New Issue
Block a user