Added a sub-class of tf.random.Generator in keras/layers/preprocessing/image_preprocessing.py to temporarily allow creating generators inside distribution strategies (with all replicas generating the same numbers).
PiperOrigin-RevId: 315586245 Change-Id: If6d80d82aa5ba1828c8c00c7aef35ba12a871694
This commit is contained in:
		
							parent
							
								
									894f1324dd
								
							
						
					
					
						commit
						b8cce8c2c4
					
				| @ -36,6 +36,7 @@ from tensorflow.python.ops import image_ops | |||||||
| from tensorflow.python.ops import math_ops | from tensorflow.python.ops import math_ops | ||||||
| from tensorflow.python.ops import stateful_random_ops | from tensorflow.python.ops import stateful_random_ops | ||||||
| from tensorflow.python.ops import stateless_random_ops | from tensorflow.python.ops import stateless_random_ops | ||||||
|  | from tensorflow.python.ops import variables | ||||||
| from tensorflow.python.util.tf_export import keras_export | from tensorflow.python.util.tf_export import keras_export | ||||||
| 
 | 
 | ||||||
| ResizeMethod = image_ops.ResizeMethod | ResizeMethod = image_ops.ResizeMethod | ||||||
| @ -1292,11 +1293,32 @@ class RandomWidth(Layer): | |||||||
|     return dict(list(base_config.items()) + list(config.items())) |     return dict(list(base_config.items()) + list(config.items())) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | # TODO(b/147877541, b/158339556): This class is added to temporarily enable | ||||||
|  | # creating generators within distribution strategies. Remove it when the proper | ||||||
|  | # API is in place. | ||||||
|  | class _RandomGenerator(stateful_random_ops.Generator): | ||||||
|  |   """A subclass that allows creation inside distribution strategies. | ||||||
|  | 
 | ||||||
|  |   This is a temporary solution to allow creating tf.random.Generator inside | ||||||
|  |   distribution strategies. It will be removed when proper API is in place. | ||||||
|  | 
 | ||||||
|  |   All replicas will have the same RNG state and generate the same random | ||||||
|  |   numbers. | ||||||
|  |   """ | ||||||
|  | 
 | ||||||
|  |   def _create_variable(self, *args, **kwargs): | ||||||
|  |     # This function does the same thing as the base class's namesake, except | ||||||
|  |     # that it skips the distribution-strategy check. When we are inside a | ||||||
|  |     # distribution-strategy scope, variables.Variable will pick a proper | ||||||
|  |     # variable class (e.g. MirroredVariable). | ||||||
|  |     return variables.Variable(*args, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def make_generator(seed=None): | def make_generator(seed=None): | ||||||
|   if seed: |   if seed: | ||||||
|     return stateful_random_ops.Generator.from_seed(seed) |     return _RandomGenerator.from_seed(seed) | ||||||
|   else: |   else: | ||||||
|     return stateful_random_ops.Generator.from_non_deterministic_state() |     return _RandomGenerator.from_non_deterministic_state() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_interpolation(interpolation): | def get_interpolation(interpolation): | ||||||
|  | |||||||
| @ -21,6 +21,7 @@ from __future__ import print_function | |||||||
| from absl.testing import parameterized | from absl.testing import parameterized | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
|  | from tensorflow.python.distribute.mirrored_strategy import MirroredStrategy | ||||||
| from tensorflow.python.framework import errors | from tensorflow.python.framework import errors | ||||||
| from tensorflow.python.framework import test_util as tf_test_util | from tensorflow.python.framework import test_util as tf_test_util | ||||||
| from tensorflow.python.keras import keras_parameterized | from tensorflow.python.keras import keras_parameterized | ||||||
| @ -962,6 +963,21 @@ class RandomRotationTest(keras_parameterized.TestCase): | |||||||
|         actual_output = layer(input_images, training=0) |         actual_output = layer(input_images, training=0) | ||||||
|         self.assertAllClose(expected_output, actual_output) |         self.assertAllClose(expected_output, actual_output) | ||||||
| 
 | 
 | ||||||
|  |   def test_distribution_strategy(self): | ||||||
|  |     """Tests that RandomRotation can be created within distribution strategies. | ||||||
|  | 
 | ||||||
|  |     And that replicas got the same random result. | ||||||
|  |     """ | ||||||
|  |     input_images = np.random.random((2, 5, 8, 3)).astype(np.float32) | ||||||
|  |     with tf_test_util.use_gpu(): | ||||||
|  |       strat = MirroredStrategy(devices=['cpu', 'gpu']) | ||||||
|  |       with strat.scope(): | ||||||
|  |         layer = image_preprocessing.RandomRotation(.5) | ||||||
|  |         output = strat.run(lambda: layer(input_images, training=True)) | ||||||
|  |       values = output.values | ||||||
|  |       self.assertAllEqual(2, len(values)) | ||||||
|  |       self.assertAllClose(values[0], values[1]) | ||||||
|  | 
 | ||||||
|   @tf_test_util.run_v2_only |   @tf_test_util.run_v2_only | ||||||
|   def test_config_with_custom_name(self): |   def test_config_with_custom_name(self): | ||||||
|     layer = image_preprocessing.RandomRotation(.5, name='image_preproc') |     layer = image_preprocessing.RandomRotation(.5, name='image_preproc') | ||||||
|  | |||||||
| @ -255,27 +255,6 @@ class GeneratorSpec(type_spec.TypeSpec): | |||||||
|     return (self.shape, self.dtype, self.alg) |     return (self.shape, self.dtype, self.alg) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _create_variable(*args, **kwargs): |  | ||||||
|   """Creates a variable, and check that it's not MirroredVariable. |  | ||||||
| 
 |  | ||||||
|   Args: |  | ||||||
|     *args: positional arguments passed along to `variables.Variable. |  | ||||||
|     **kwargs: keyword arguments passed along to `variables.Variable. |  | ||||||
| 
 |  | ||||||
|   Returns: |  | ||||||
|     The created variable. |  | ||||||
|   """ |  | ||||||
|   if ds_context.has_strategy(): |  | ||||||
|     raise ValueError( |  | ||||||
|         "Creating a generator within a strategy scope is disallowed, because " |  | ||||||
|         "there is ambiguity on how to replicate a generator (e.g. should it be " |  | ||||||
|         "copied so that each replica gets the same random numbers, or 'split' " |  | ||||||
|         "so that each replica gets different random numbers).") |  | ||||||
|     # TODO(wangpeng): Link to the RNG guide for solutions in such cases. |  | ||||||
|   var = variables.Variable(*args, **kwargs) |  | ||||||
|   return var |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @tf_export("random.Generator", "random.experimental.Generator") | @tf_export("random.Generator", "random.experimental.Generator") | ||||||
| class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor): | class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor): | ||||||
|   """Random-number generator. |   """Random-number generator. | ||||||
| @ -367,7 +346,7 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor): | |||||||
|     if copy_from is not None: |     if copy_from is not None: | ||||||
|       # All other arguments should be None |       # All other arguments should be None | ||||||
|       assert (alg or state) is None |       assert (alg or state) is None | ||||||
|       self._state_var = _create_variable(copy_from.state, dtype=STATE_TYPE, |       self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE, | ||||||
|                                               trainable=False) |                                               trainable=False) | ||||||
|       self._alg = copy_from.algorithm |       self._alg = copy_from.algorithm | ||||||
| 
 | 
 | ||||||
| @ -380,10 +359,30 @@ class Generator(tracking.AutoTrackable, composite_tensor.CompositeTensor): | |||||||
|       else: |       else: | ||||||
|         state = _convert_to_state_tensor(state) |         state = _convert_to_state_tensor(state) | ||||||
|         _check_state_shape(state.shape, alg) |         _check_state_shape(state.shape, alg) | ||||||
|         self._state_var = _create_variable(state, dtype=STATE_TYPE, |         self._state_var = self._create_variable(state, dtype=STATE_TYPE, | ||||||
|                                                 trainable=False) |                                                 trainable=False) | ||||||
|       self._alg = alg |       self._alg = alg | ||||||
| 
 | 
 | ||||||
|  |   def _create_variable(self, *args, **kwargs): | ||||||
|  |     """Creates a variable, and check that it's not MirroredVariable. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |       *args: positional arguments passed along to `variables.Variable. | ||||||
|  |       **kwargs: keyword arguments passed along to `variables.Variable. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |       The created variable. | ||||||
|  |     """ | ||||||
|  |     if ds_context.has_strategy(): | ||||||
|  |       raise ValueError( | ||||||
|  |           "Creating a generator within a strategy scope is disallowed, because " | ||||||
|  |           "there is ambiguity on how to replicate a generator (e.g. should it " | ||||||
|  |           "be copied so that each replica gets the same random numbers, or " | ||||||
|  |           "'split' so that each replica gets different random numbers).") | ||||||
|  |       # TODO(wangpeng): Link to the RNG guide for solutions in such cases. | ||||||
|  |     var = variables.Variable(*args, **kwargs) | ||||||
|  |     return var | ||||||
|  | 
 | ||||||
|   @classmethod |   @classmethod | ||||||
|   def from_state(cls, state, alg): |   def from_state(cls, state, alg): | ||||||
|     """Creates a generator from a state. |     """Creates a generator from a state. | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user