Added a sub-class of tf.random.Generator in keras/layers/preprocessing/image_preprocessing.py to temporarily allow creating generators inside distribution strategies (with all replicas generating the same numbers).
PiperOrigin-RevId: 315586245 Change-Id: If6d80d82aa5ba1828c8c00c7aef35ba12a871694
This commit is contained in:
parent
894f1324dd
commit
b8cce8c2c4
@ -36,6 +36,7 @@ from tensorflow.python.ops import image_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import stateful_random_ops
|
from tensorflow.python.ops import stateful_random_ops
|
||||||
from tensorflow.python.ops import stateless_random_ops
|
from tensorflow.python.ops import stateless_random_ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
ResizeMethod = image_ops.ResizeMethod
|
ResizeMethod = image_ops.ResizeMethod
|
||||||
@ -1292,11 +1293,32 @@ class RandomWidth(Layer):
|
|||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(b/147877541, b/158339556): This class is added to temporarily enable
|
||||||
|
# creating generators within distribution strategies. Remove it when the proper
|
||||||
|
# API is in place.
|
||||||
|
class _RandomGenerator(stateful_random_ops.Generator):
|
||||||
|
"""A subclass that allows creation inside distribution strategies.
|
||||||
|
|
||||||
|
This is a temporary solution to allow creating tf.random.Generator inside
|
||||||
|
distribution strategies. It will be removed when proper API is in place.
|
||||||
|
|
||||||
|
All replicas will have the same RNG state and generate the same random
|
||||||
|
numbers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _create_variable(self, *args, **kwargs):
|
||||||
|
# This function does the same thing as the base class's namesake, except
|
||||||
|
# that it skips the distribution-strategy check. When we are inside a
|
||||||
|
# distribution-strategy scope, variables.Variable will pick a proper
|
||||||
|
# variable class (e.g. MirroredVariable).
|
||||||
|
return variables.Variable(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def make_generator(seed=None):
|
def make_generator(seed=None):
|
||||||
if seed:
|
if seed:
|
||||||
return stateful_random_ops.Generator.from_seed(seed)
|
return _RandomGenerator.from_seed(seed)
|
||||||
else:
|
else:
|
||||||
return stateful_random_ops.Generator.from_non_deterministic_state()
|
return _RandomGenerator.from_non_deterministic_state()
|
||||||
|
|
||||||
|
|
||||||
def get_interpolation(interpolation):
|
def get_interpolation(interpolation):
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import test_util as tf_test_util
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
@ -962,6 +963,21 @@ class RandomRotationTest(keras_parameterized.TestCase):
|
|||||||
actual_output = layer(input_images, training=0)
|
actual_output = layer(input_images, training=0)
|
||||||
self.assertAllClose(expected_output, actual_output)
|
self.assertAllClose(expected_output, actual_output)
|
||||||
|
|
||||||
|
def test_distribution_strategy(self):
|
||||||
|
"""Tests that RandomRotation can be created within distribution strategies.
|
||||||
|
|
||||||
|
And that replicas got the same random result.
|
||||||
|
"""
|
||||||
|
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
|
||||||
|
with tf_test_util.use_gpu():
|
||||||
|
strat = MirroredStrategy(devices=['cpu', 'gpu'])
|
||||||
|
with strat.scope():
|
||||||
|
layer = image_preprocessing.RandomRotation(.5)
|
||||||
|
output = strat.run(lambda: layer(input_images, training=True))
|
||||||
|
values = output.values
|
||||||
|
self.assertAllEqual(2, len(values))
|
||||||
|
self.assertAllClose(values[0], values[1])
|
||||||
|
|
||||||
@tf_test_util.run_v2_only
|
@tf_test_util.run_v2_only
|
||||||
def test_config_with_custom_name(self):
|
def test_config_with_custom_name(self):
|
||||||
layer = image_preprocessing.RandomRotation(.5, name='image_preproc')
|
layer = image_preprocessing.RandomRotation(.5, name='image_preproc')
|
||||||
|
@ -255,27 +255,6 @@ class GeneratorSpec(type_spec.TypeSpec):
|
|||||||
return (self.shape, self.dtype, self.alg)
|
return (self.shape, self.dtype, self.alg)
|
||||||
|
|
||||||
|
|
||||||
def _create_variable(*args, **kwargs):
|
|
||||||
"""Creates a variable, and check that it's not MirroredVariable.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*args: positional arguments passed along to `variables.Variable.
|
|
||||||
**kwargs: keyword arguments passed along to `variables.Variable.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The created variable.
|
|
||||||
"""
|
|
||||||
if ds_context.has_strategy():
|
|
||||||
raise ValueError(
|
|
||||||
"Creating a generator within a strategy scope is disallowed, because "
|
|
||||||
"there is ambiguity on how to replicate a generator (e.g. should it be "
|
|
||||||
"copied so that each replica gets the same random numbers, or 'split' "
|
|
||||||
"so that each replica gets different random numbers).")
|
|
||||||
# TODO(wangpeng): Link to the RNG guide for solutions in such cases.
|
|
||||||
var = variables.Variable(*args, **kwargs)
|
|
||||||
return var
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("random.Generator", "random.experimental.Generator")
|
@tf_export("random.Generator", "random.experimental.Generator")
|
||||||
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
||||||
"""Random-number generator.
|
"""Random-number generator.
|
||||||
@ -367,7 +346,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
if copy_from is not None:
|
if copy_from is not None:
|
||||||
# All other arguments should be None
|
# All other arguments should be None
|
||||||
assert (alg or state) is None
|
assert (alg or state) is None
|
||||||
self._state_var = _create_variable(copy_from.state, dtype=STATE_TYPE,
|
self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
|
||||||
trainable=False)
|
trainable=False)
|
||||||
self._alg = copy_from.algorithm
|
self._alg = copy_from.algorithm
|
||||||
|
|
||||||
@ -380,10 +359,30 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
|
|||||||
else:
|
else:
|
||||||
state = _convert_to_state_tensor(state)
|
state = _convert_to_state_tensor(state)
|
||||||
_check_state_shape(state.shape, alg)
|
_check_state_shape(state.shape, alg)
|
||||||
self._state_var = _create_variable(state, dtype=STATE_TYPE,
|
self._state_var = self._create_variable(state, dtype=STATE_TYPE,
|
||||||
trainable=False)
|
trainable=False)
|
||||||
self._alg = alg
|
self._alg = alg
|
||||||
|
|
||||||
|
def _create_variable(self, *args, **kwargs):
|
||||||
|
"""Creates a variable, and check that it's not MirroredVariable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: positional arguments passed along to `variables.Variable.
|
||||||
|
**kwargs: keyword arguments passed along to `variables.Variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created variable.
|
||||||
|
"""
|
||||||
|
if ds_context.has_strategy():
|
||||||
|
raise ValueError(
|
||||||
|
"Creating a generator within a strategy scope is disallowed, because "
|
||||||
|
"there is ambiguity on how to replicate a generator (e.g. should it "
|
||||||
|
"be copied so that each replica gets the same random numbers, or "
|
||||||
|
"'split' so that each replica gets different random numbers).")
|
||||||
|
# TODO(wangpeng): Link to the RNG guide for solutions in such cases.
|
||||||
|
var = variables.Variable(*args, **kwargs)
|
||||||
|
return var
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_state(cls, state, alg):
|
def from_state(cls, state, alg):
|
||||||
"""Creates a generator from a state.
|
"""Creates a generator from a state.
|
||||||
|
Loading…
Reference in New Issue
Block a user