Added a sub-class of tf.random.Generator in keras/layers/preprocessing/image_preprocessing.py to temporarily allow creating generators inside distribution strategies (with all replicas generating the same numbers).
PiperOrigin-RevId: 315586245 Change-Id: If6d80d82aa5ba1828c8c00c7aef35ba12a871694
This commit is contained in:
		
							parent
							
								
									894f1324dd
								
							
						
					
					
						commit
						b8cce8c2c4
					
				@ -36,6 +36,7 @@ from tensorflow.python.ops import image_ops
 | 
				
			|||||||
from tensorflow.python.ops import math_ops
 | 
					from tensorflow.python.ops import math_ops
 | 
				
			||||||
from tensorflow.python.ops import stateful_random_ops
 | 
					from tensorflow.python.ops import stateful_random_ops
 | 
				
			||||||
from tensorflow.python.ops import stateless_random_ops
 | 
					from tensorflow.python.ops import stateless_random_ops
 | 
				
			||||||
 | 
					from tensorflow.python.ops import variables
 | 
				
			||||||
from tensorflow.python.util.tf_export import keras_export
 | 
					from tensorflow.python.util.tf_export import keras_export
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ResizeMethod = image_ops.ResizeMethod
 | 
					ResizeMethod = image_ops.ResizeMethod
 | 
				
			||||||
@ -1292,11 +1293,32 @@ class RandomWidth(Layer):
 | 
				
			|||||||
    return dict(list(base_config.items()) + list(config.items()))
 | 
					    return dict(list(base_config.items()) + list(config.items()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TODO(b/147877541, b/158339556): This class is added to temporarily enable
 | 
				
			||||||
 | 
					# creating generators within distribution strategies. Remove it when the proper
 | 
				
			||||||
 | 
					# API is in place.
 | 
				
			||||||
 | 
					class _RandomGenerator(stateful_random_ops.Generator):
 | 
				
			||||||
 | 
					  """A subclass that allows creation inside distribution strategies.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  This is a temporary solution to allow creating tf.random.Generator inside
 | 
				
			||||||
 | 
					  distribution strategies. It will be removed when proper API is in place.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  All replicas will have the same RNG state and generate the same random
 | 
				
			||||||
 | 
					  numbers.
 | 
				
			||||||
 | 
					  """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_variable(self, *args, **kwargs):
 | 
				
			||||||
 | 
					    # This function does the same thing as the base class's namesake, except
 | 
				
			||||||
 | 
					    # that it skips the distribution-strategy check. When we are inside a
 | 
				
			||||||
 | 
					    # distribution-strategy scope, variables.Variable will pick a proper
 | 
				
			||||||
 | 
					    # variable class (e.g. MirroredVariable).
 | 
				
			||||||
 | 
					    return variables.Variable(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def make_generator(seed=None):
 | 
					def make_generator(seed=None):
 | 
				
			||||||
  if seed:
 | 
					  if seed:
 | 
				
			||||||
    return stateful_random_ops.Generator.from_seed(seed)
 | 
					    return _RandomGenerator.from_seed(seed)
 | 
				
			||||||
  else:
 | 
					  else:
 | 
				
			||||||
    return stateful_random_ops.Generator.from_non_deterministic_state()
 | 
					    return _RandomGenerator.from_non_deterministic_state()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_interpolation(interpolation):
 | 
					def get_interpolation(interpolation):
 | 
				
			||||||
 | 
				
			|||||||
@ -21,6 +21,7 @@ from __future__ import print_function
 | 
				
			|||||||
from absl.testing import parameterized
 | 
					from absl.testing import parameterized
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy
 | 
				
			||||||
from tensorflow.python.framework import errors
 | 
					from tensorflow.python.framework import errors
 | 
				
			||||||
from tensorflow.python.framework import test_util as tf_test_util
 | 
					from tensorflow.python.framework import test_util as tf_test_util
 | 
				
			||||||
from tensorflow.python.keras import keras_parameterized
 | 
					from tensorflow.python.keras import keras_parameterized
 | 
				
			||||||
@ -962,6 +963,21 @@ class RandomRotationTest(keras_parameterized.TestCase):
 | 
				
			|||||||
        actual_output = layer(input_images, training=0)
 | 
					        actual_output = layer(input_images, training=0)
 | 
				
			||||||
        self.assertAllClose(expected_output, actual_output)
 | 
					        self.assertAllClose(expected_output, actual_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def test_distribution_strategy(self):
 | 
				
			||||||
 | 
					    """Tests that RandomRotation can be created within distribution strategies.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    And that replicas got the same random result.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
 | 
				
			||||||
 | 
					    with tf_test_util.use_gpu():
 | 
				
			||||||
 | 
					      strat = MirroredStrategy(devices=['cpu', 'gpu'])
 | 
				
			||||||
 | 
					      with strat.scope():
 | 
				
			||||||
 | 
					        layer = image_preprocessing.RandomRotation(.5)
 | 
				
			||||||
 | 
					        output = strat.run(lambda: layer(input_images, training=True))
 | 
				
			||||||
 | 
					      values = output.values
 | 
				
			||||||
 | 
					      self.assertAllEqual(2, len(values))
 | 
				
			||||||
 | 
					      self.assertAllClose(values[0], values[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @tf_test_util.run_v2_only
 | 
					  @tf_test_util.run_v2_only
 | 
				
			||||||
  def test_config_with_custom_name(self):
 | 
					  def test_config_with_custom_name(self):
 | 
				
			||||||
    layer = image_preprocessing.RandomRotation(.5, name='image_preproc')
 | 
					    layer = image_preprocessing.RandomRotation(.5, name='image_preproc')
 | 
				
			||||||
 | 
				
			|||||||
@ -255,27 +255,6 @@ class GeneratorSpec(type_spec.TypeSpec):
 | 
				
			|||||||
    return (self.shape, self.dtype, self.alg)
 | 
					    return (self.shape, self.dtype, self.alg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _create_variable(*args, **kwargs):
 | 
					 | 
				
			||||||
  """Creates a variable, and check that it's not MirroredVariable.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  Args:
 | 
					 | 
				
			||||||
    *args: positional arguments passed along to `variables.Variable.
 | 
					 | 
				
			||||||
    **kwargs: keyword arguments passed along to `variables.Variable.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  Returns:
 | 
					 | 
				
			||||||
    The created variable.
 | 
					 | 
				
			||||||
  """
 | 
					 | 
				
			||||||
  if ds_context.has_strategy():
 | 
					 | 
				
			||||||
    raise ValueError(
 | 
					 | 
				
			||||||
        "Creating a generator within a strategy scope is disallowed, because "
 | 
					 | 
				
			||||||
        "there is ambiguity on how to replicate a generator (e.g. should it be "
 | 
					 | 
				
			||||||
        "copied so that each replica gets the same random numbers, or 'split' "
 | 
					 | 
				
			||||||
        "so that each replica gets different random numbers).")
 | 
					 | 
				
			||||||
    # TODO(wangpeng): Link to the RNG guide for solutions in such cases.
 | 
					 | 
				
			||||||
  var = variables.Variable(*args, **kwargs)
 | 
					 | 
				
			||||||
  return var
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@tf_export("random.Generator", "random.experimental.Generator")
 | 
					@tf_export("random.Generator", "random.experimental.Generator")
 | 
				
			||||||
class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
 | 
					class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
 | 
				
			||||||
  """Random-number generator.
 | 
					  """Random-number generator.
 | 
				
			||||||
@ -367,7 +346,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
 | 
				
			|||||||
    if copy_from is not None:
 | 
					    if copy_from is not None:
 | 
				
			||||||
      # All other arguments should be None
 | 
					      # All other arguments should be None
 | 
				
			||||||
      assert (alg or state) is None
 | 
					      assert (alg or state) is None
 | 
				
			||||||
      self._state_var = _create_variable(copy_from.state, dtype=STATE_TYPE,
 | 
					      self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
 | 
				
			||||||
                                              trainable=False)
 | 
					                                              trainable=False)
 | 
				
			||||||
      self._alg = copy_from.algorithm
 | 
					      self._alg = copy_from.algorithm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -380,10 +359,30 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor):
 | 
				
			|||||||
      else:
 | 
					      else:
 | 
				
			||||||
        state = _convert_to_state_tensor(state)
 | 
					        state = _convert_to_state_tensor(state)
 | 
				
			||||||
        _check_state_shape(state.shape, alg)
 | 
					        _check_state_shape(state.shape, alg)
 | 
				
			||||||
        self._state_var = _create_variable(state, dtype=STATE_TYPE,
 | 
					        self._state_var = self._create_variable(state, dtype=STATE_TYPE,
 | 
				
			||||||
                                                trainable=False)
 | 
					                                                trainable=False)
 | 
				
			||||||
      self._alg = alg
 | 
					      self._alg = alg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  def _create_variable(self, *args, **kwargs):
 | 
				
			||||||
 | 
					    """Creates a variable, and check that it's not MirroredVariable.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					      *args: positional arguments passed along to `variables.Variable.
 | 
				
			||||||
 | 
					      **kwargs: keyword arguments passed along to `variables.Variable.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					      The created variable.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if ds_context.has_strategy():
 | 
				
			||||||
 | 
					      raise ValueError(
 | 
				
			||||||
 | 
					          "Creating a generator within a strategy scope is disallowed, because "
 | 
				
			||||||
 | 
					          "there is ambiguity on how to replicate a generator (e.g. should it "
 | 
				
			||||||
 | 
					          "be copied so that each replica gets the same random numbers, or "
 | 
				
			||||||
 | 
					          "'split' so that each replica gets different random numbers).")
 | 
				
			||||||
 | 
					      # TODO(wangpeng): Link to the RNG guide for solutions in such cases.
 | 
				
			||||||
 | 
					    var = variables.Variable(*args, **kwargs)
 | 
				
			||||||
 | 
					    return var
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  @classmethod
 | 
					  @classmethod
 | 
				
			||||||
  def from_state(cls, state, alg):
 | 
					  def from_state(cls, state, alg):
 | 
				
			||||||
    """Creates a generator from a state.
 | 
					    """Creates a generator from a state.
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user