Added a sub-class of tf.random.Generator in keras/layers/preprocessing/image_preprocessing.py to temporarily allow creating generators inside distribution strategies (with all replicas generating the same numbers).
PiperOrigin-RevId: 315586245 Change-Id: If6d80d82aa5ba1828c8c00c7aef35ba12a871694
This commit is contained in:
parent
894f1324dd
commit
b8cce8c2c4
@ -36,6 +36,7 @@ from tensorflow.python.ops import image_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import stateful_random_ops
|
||||
from tensorflow.python.ops import stateless_random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
ResizeMethod = image_ops.ResizeMethod
|
||||
@ -1292,11 +1293,32 @@ class RandomWidth(Layer):
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
# TODO(b/147877541, b/158339556): This class is added to temporarily enable
|
||||
# creating generators within distribution strategies. Remove it when the proper
|
||||
# API is in place.
|
||||
class _RandomGenerator(stateful_random_ops.Generator):
|
||||
"""A subclass that allows creation inside distribution strategies.
|
||||
|
||||
This is a temporary solution to allow creating tf.random.Generator inside
|
||||
distribution strategies. It will be removed when proper API is in place.
|
||||
|
||||
All replicas will have the same RNG state and generate the same random
|
||||
numbers.
|
||||
"""
|
||||
|
||||
def _create_variable(self, *args, **kwargs):
|
||||
# This function does the same thing as the base class's namesake, except
|
||||
# that it skips the distribution-strategy check. When we are inside a
|
||||
# distribution-strategy scope, variables.Variable will pick a proper
|
||||
# variable class (e.g. MirroredVariable).
|
||||
return variables.Variable(*args, **kwargs)
|
||||
|
||||
|
||||
def make_generator(seed=None):
|
||||
if seed:
|
||||
return stateful_random_ops.Generator.from_seed(seed)
|
||||
return _RandomGenerator.from_seed(seed)
|
||||
else:
|
||||
return stateful_random_ops.Generator.from_non_deterministic_state()
|
||||
return _RandomGenerator.from_non_deterministic_state()
|
||||
|
||||
|
||||
def get_interpolation(interpolation):
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
@ -962,6 +963,21 @@ class RandomRotationTest(keras_parameterized.TestCase):
|
||||
actual_output = layer(input_images, training=0)
|
||||
self.assertAllClose(expected_output, actual_output)
|
||||
|
||||
def test_distribution_strategy(self):
|
||||
"""Tests that RandomRotation can be created within distribution strategies.
|
||||
|
||||
And that replicas got the same random result.
|
||||
"""
|
||||
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
|
||||
with tf_test_util.use_gpu():
|
||||
strat = MirroredStrategy(devices=['cpu', 'gpu'])
|
||||
with strat.scope():
|
||||
layer = image_preprocessing.RandomRotation(.5)
|
||||
output = strat.run(lambda: layer(input_images, training=True))
|
||||
values = output.values
|
||||
self.assertAllEqual(2, len(values))
|
||||
self.assertAllClose(values[0], values[1])
|
||||
|
||||
@tf_test_util.run_v2_only
|
||||
def test_config_with_custom_name(self):
|
||||
layer = image_preprocessing.RandomRotation(.5, name='image_preproc')
|
||||
|
@ -255,27 +255,6 @@ class GeneratorSpec(type_spec.TypeSpec):
|
||||
return (self.shape, self.dtype, self.alg)
|
||||
|
||||
|
||||
def _create_variable(*args, **kwargs):
|
||||
"""Creates a variable, and check that it's not MirroredVariable.
|
||||
|
||||
Args:
|
||||
*args: positional arguments passed along to `variables.Variable.
|
||||
**kwargs: keyword arguments passed along to `variables.Variable.
|
||||
|
||||
Returns:
|
||||
The created variable.
|
||||
"""
|
||||
if ds_context.has_strategy():
|
||||
raise ValueError(
|
||||
"Creating a generator within a strategy scope is disallowed, because "
|
||||
"there is ambiguity on how to replicate a generator (e.g. should it be "
|
||||
"copied so that each replica gets the same random numbers, or 'split' "
|
||||
"so that each replica gets different random numbers).")
|
||||
# TODO(wangpeng): Link to the RNG guide for solutions in such cases.
|
||||
var = variables.Variable(*args, **kwargs)
|
||||
return var
|
||||
|
||||
|
||||
@tf_export("random.Generator", "random.experimental.Generator")
|
||||
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
"""Random-number generator.
|
||||
@ -367,8 +346,8 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
if copy_from is not None:
|
||||
# All other arguments should be None
|
||||
assert (alg or state) is None
|
||||
self._state_var = _create_variable(copy_from.state, dtype=STATE_TYPE,
|
||||
trainable=False)
|
||||
self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
|
||||
trainable=False)
|
||||
self._alg = copy_from.algorithm
|
||||
|
||||
else:
|
||||
@ -380,10 +359,30 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||
else:
|
||||
state = _convert_to_state_tensor(state)
|
||||
_check_state_shape(state.shape, alg)
|
||||
self._state_var = _create_variable(state, dtype=STATE_TYPE,
|
||||
trainable=False)
|
||||
self._state_var = self._create_variable(state, dtype=STATE_TYPE,
|
||||
trainable=False)
|
||||
self._alg = alg
|
||||
|
||||
def _create_variable(self, *args, **kwargs):
|
||||
"""Creates a variable, and check that it's not MirroredVariable.
|
||||
|
||||
Args:
|
||||
*args: positional arguments passed along to `variables.Variable.
|
||||
**kwargs: keyword arguments passed along to `variables.Variable.
|
||||
|
||||
Returns:
|
||||
The created variable.
|
||||
"""
|
||||
if ds_context.has_strategy():
|
||||
raise ValueError(
|
||||
"Creating a generator within a strategy scope is disallowed, because "
|
||||
"there is ambiguity on how to replicate a generator (e.g. should it "
|
||||
"be copied so that each replica gets the same random numbers, or "
|
||||
"'split' so that each replica gets different random numbers).")
|
||||
# TODO(wangpeng): Link to the RNG guide for solutions in such cases.
|
||||
var = variables.Variable(*args, **kwargs)
|
||||
return var
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state, alg):
|
||||
"""Creates a generator from a state.
|
||||
|
Loading…
Reference in New Issue
Block a user