Added a sub-class of tf.random.Generator in keras/layers/preprocessing/image_preprocessing.py to temporarily allow creating generators inside distribution strategies (with all replicas generating the same numbers).

PiperOrigin-RevId: 315586245
Change-Id: If6d80d82aa5ba1828c8c00c7aef35ba12a871694
Peng Wang 2020-06-09 16:40:47 -07:00, committed by TensorFlower Gardener
parent 894f1324dd
commit b8cce8c2c4
3 changed files with 64 additions and 27 deletions
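
The sketch below is not part of the commit; it illustrates the behaviour the commit message describes, assuming a TF 2.x build in which the image preprocessing layers are exported under tf.keras.layers.experimental.preprocessing. The logical-device setup exists only to give MirroredStrategy two replicas on a CPU-only machine.

import numpy as np
import tensorflow as tf

# Illustrative setup: split the CPU into two logical devices so that
# MirroredStrategy has two replicas to compare, even without GPUs.
cpu = tf.config.list_physical_devices('CPU')[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration(),
          tf.config.LogicalDeviceConfiguration()])
strategy = tf.distribute.MirroredStrategy(['/cpu:0', '/cpu:1'])

images = np.random.random((2, 5, 8, 3)).astype(np.float32)
with strategy.scope():
  # Before this change, building the layer here failed because its
  # tf.random.Generator refused to be created under a strategy scope.
  layer = tf.keras.layers.experimental.preprocessing.RandomRotation(0.5)

per_replica = strategy.run(lambda: layer(images, training=True))
first, second = per_replica.values
# All replicas share one RNG state, so they apply identical random rotations.
np.testing.assert_allclose(first.numpy(), second.numpy())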

tensorflow/python/keras/layers/preprocessing/image_preprocessing.py

@@ -36,6 +36,7 @@ from tensorflow.python.ops import image_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import stateful_random_ops
 from tensorflow.python.ops import stateless_random_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.util.tf_export import keras_export
 
 ResizeMethod = image_ops.ResizeMethod
@@ -1292,11 +1293,32 @@ class RandomWidth(Layer):
     return dict(list(base_config.items()) + list(config.items()))
 
 
+# TODO(b/147877541, b/158339556): This class is added to temporarily enable
+# creating generators within distribution strategies. Remove it when the proper
+# API is in place.
+class _RandomGenerator(stateful_random_ops.Generator):
+  """A subclass that allows creation inside distribution strategies.
+
+  This is a temporary solution to allow creating tf.random.Generator inside
+  distribution strategies. It will be removed when proper API is in place.
+
+  All replicas will have the same RNG state and generate the same random
+  numbers.
+  """
+
+  def _create_variable(self, *args, **kwargs):
+    # This function does the same thing as the base class's namesake, except
+    # that it skips the distribution-strategy check. When we are inside a
+    # distribution-strategy scope, variables.Variable will pick a proper
+    # variable class (e.g. MirroredVariable).
+    return variables.Variable(*args, **kwargs)
+
+
 def make_generator(seed=None):
   if seed:
-    return stateful_random_ops.Generator.from_seed(seed)
+    return _RandomGenerator.from_seed(seed)
   else:
-    return stateful_random_ops.Generator.from_non_deterministic_state()
+    return _RandomGenerator.from_non_deterministic_state()
 
 
 def get_interpolation(interpolation):
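
A hedged usage sketch of the helper above, not taken from the commit: it assumes the internal module path tensorflow.python.keras.layers.preprocessing.image_preprocessing and a TF 2.x MirroredStrategy. Inside the strategy scope the public constructor still raises (see the stateful_random_ops.py hunks further down), while make_generator builds a _RandomGenerator whose state variable is mirrored, so every replica draws the same numbers.

import tensorflow as tf
from tensorflow.python.keras.layers.preprocessing import image_preprocessing

strategy = tf.distribute.MirroredStrategy()  # one or more replicas, depending on devices
with strategy.scope():
  try:
    tf.random.Generator.from_seed(1234)      # still disallowed under a strategy scope
  except ValueError as e:
    print('public API:', e)
  gen = image_preprocessing.make_generator(seed=1234)  # returns a _RandomGenerator

# With several replicas, each one performs the same update on the mirrored
# state, so the draws stay identical across replicas.
print(strategy.run(lambda: gen.uniform([2])))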

tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py

@@ -21,6 +21,7 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.keras import keras_parameterized
@@ -962,6 +963,21 @@ class RandomRotationTest(keras_parameterized.TestCase):
       actual_output = layer(input_images, training=0)
       self.assertAllClose(expected_output, actual_output)
 
+  def test_distribution_strategy(self):
+    """Tests that RandomRotation can be created within distribution strategies.
+
+    And that replicas got the same random result.
+    """
+    input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
+    with tf_test_util.use_gpu():
+      strat = MirroredStrategy(devices=['cpu', 'gpu'])
+      with strat.scope():
+        layer = image_preprocessing.RandomRotation(.5)
+        output = strat.run(lambda: layer(input_images, training=True))
+      values = output.values
+      self.assertAllEqual(2, len(values))
+      self.assertAllClose(values[0], values[1])
+
   @tf_test_util.run_v2_only
   def test_config_with_custom_name(self):
     layer = image_preprocessing.RandomRotation(.5, name='image_preproc')

tensorflow/python/ops/stateful_random_ops.py
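
The hunks below move a strategy-scope check from a module-level helper into a Generator method, so that subclasses (like the Keras _RandomGenerator above) can override it. The check's error message contrasts two semantics: copying the state, so every replica draws the same numbers, versus splitting it, so replicas draw different numbers. A small sketch (not from the commit) of the splitting alternative using the public Generator.split API, outside any strategy scope and assuming TF 2.x:

import tensorflow as tf

parent = tf.random.Generator.from_seed(2020)
# split() derives independent child generators; replicas seeded this way
# would draw different numbers, unlike the mirrored-copy behaviour above.
children = parent.split(count=2)
print(children[0].normal([2]))
print(children[1].normal([2]))  # differs from the first child's draw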

@@ -255,27 +255,6 @@ class GeneratorSpec(type_spec.TypeSpec):
     return (self.shape, self.dtype, self.alg)
 
 
-def _create_variable(*args, **kwargs):
-  """Creates a variable, and check that it's not MirroredVariable.
-
-  Args:
-    *args: positional arguments passed along to `variables.Variable`.
-    **kwargs: keyword arguments passed along to `variables.Variable`.
-
-  Returns:
-    The created variable.
-  """
-  if ds_context.has_strategy():
-    raise ValueError(
-        "Creating a generator within a strategy scope is disallowed, because "
-        "there is ambiguity on how to replicate a generator (e.g. should it be "
-        "copied so that each replica gets the same random numbers, or 'split' "
-        "so that each replica gets different random numbers).")
-  # TODO(wangpeng): Link to the RNG guide for solutions in such cases.
-  var = variables.Variable(*args, **kwargs)
-  return var
-
-
 @tf_export("random.Generator", "random.experimental.Generator")
 class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
   """Random-number generator.
@@ -367,8 +346,8 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
     if copy_from is not None:
       # All other arguments should be None
       assert (alg or state) is None
-      self._state_var = _create_variable(copy_from.state, dtype=STATE_TYPE,
-                                         trainable=False)
+      self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
+                                              trainable=False)
       self._alg = copy_from.algorithm
     else:
@@ -380,10 +359,30 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
       else:
         state = _convert_to_state_tensor(state)
         _check_state_shape(state.shape, alg)
-        self._state_var = _create_variable(state, dtype=STATE_TYPE,
-                                           trainable=False)
+        self._state_var = self._create_variable(state, dtype=STATE_TYPE,
+                                                trainable=False)
       self._alg = alg
 
+  def _create_variable(self, *args, **kwargs):
+    """Creates a variable, and check that it's not MirroredVariable.
+
+    Args:
+      *args: positional arguments passed along to `variables.Variable`.
+      **kwargs: keyword arguments passed along to `variables.Variable`.
+
+    Returns:
+      The created variable.
+    """
+    if ds_context.has_strategy():
+      raise ValueError(
+          "Creating a generator within a strategy scope is disallowed, because "
+          "there is ambiguity on how to replicate a generator (e.g. should it "
+          "be copied so that each replica gets the same random numbers, or "
+          "'split' so that each replica gets different random numbers).")
+    # TODO(wangpeng): Link to the RNG guide for solutions in such cases.
+    var = variables.Variable(*args, **kwargs)
+    return var
+
   @classmethod
   def from_state(cls, state, alg):
     """Creates a generator from a state.