Create Image Preproc RandomFlip layer.
PiperOrigin-RevId: 285843812 Change-Id: Ifd8183f916d673a7e661ca5dca84312c5cc74078
This commit is contained in:
parent
ea0c4e6106
commit
1d61880a20
@ -313,6 +313,75 @@ class Rescaling(Layer):
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
class RandomFlip(Layer):
|
||||
"""Randomly flip each image horizontally and vertically.
|
||||
|
||||
This layer will by default flip the images horizontally and then vertically
|
||||
during training time.
|
||||
`RandomFlip(horizontal=True)` will only flip the input horizontally.
|
||||
`RandomFlip(vertical=True)` will only flip the input vertically.
|
||||
During inference time, the output will be identical to input. Call the layer
|
||||
with `training=True` to flip the input.
|
||||
|
||||
Input shape:
|
||||
4D tensor with shape:
|
||||
`(samples, height, width, channels)`, data_format='channels_last'.
|
||||
|
||||
Output shape:
|
||||
4D tensor with shape:
|
||||
`(samples, height, width, channels)`, data_format='channels_last'.
|
||||
|
||||
Attributes:
|
||||
horizontal: Bool, whether to randomly flip horizontally.
|
||||
width: Bool, whether to randomly flip vertically.
|
||||
seed: Integer. Used to create a random seed.
|
||||
"""
|
||||
|
||||
def __init__(self, horizontal=None, vertical=None, seed=None, **kwargs):
|
||||
# If both arguments are None, set both to True.
|
||||
if horizontal is None and vertical is None:
|
||||
self.horizontal = True
|
||||
self.vertical = True
|
||||
else:
|
||||
self.horizontal = horizontal or False
|
||||
self.vertical = vertical or False
|
||||
self.seed = seed
|
||||
self._rng = make_generator(self.seed)
|
||||
self.input_spec = InputSpec(ndim=4)
|
||||
super(RandomFlip, self).__init__(**kwargs)
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
if training is None:
|
||||
training = K.learning_phase()
|
||||
|
||||
def random_flipped_inputs():
|
||||
flipped_outputs = inputs
|
||||
if self.horizontal:
|
||||
flipped_outputs = image_ops.random_flip_up_down(flipped_outputs,
|
||||
self.seed)
|
||||
if self.vertical:
|
||||
flipped_outputs = image_ops.random_flip_left_right(
|
||||
flipped_outputs, self.seed)
|
||||
return flipped_outputs
|
||||
|
||||
output = tf_utils.smart_cond(training, random_flipped_inputs,
|
||||
lambda: inputs)
|
||||
output.set_shape(inputs.shape)
|
||||
return output
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'horizontal': self.horizontal,
|
||||
'vertical': self.vertical,
|
||||
'seed': self.seed,
|
||||
}
|
||||
base_config = super(RandomFlip, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
|
||||
def make_generator(seed=None):
|
||||
if seed:
|
||||
return stateful_random_ops.Generator.from_seed(seed)
|
||||
|
||||
@ -286,5 +286,93 @@ class RescalingTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(layer_1.name, layer.name)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
class RandomFlipTest(keras_parameterized.TestCase):
|
||||
|
||||
def _run_test(self,
|
||||
flip_horizontal,
|
||||
flip_vertical,
|
||||
expected_output=None,
|
||||
mock_random=None):
|
||||
np.random.seed(1337)
|
||||
num_samples = 2
|
||||
orig_height = 5
|
||||
orig_width = 8
|
||||
channels = 3
|
||||
if mock_random is None:
|
||||
mock_random = [1 for _ in range(num_samples)]
|
||||
mock_random = np.reshape(mock_random, [2, 1, 1, 1])
|
||||
inp = np.random.random((num_samples, orig_height, orig_width, channels))
|
||||
if expected_output is None:
|
||||
expected_output = inp
|
||||
if flip_horizontal:
|
||||
expected_output = np.flip(expected_output, axis=1)
|
||||
if flip_vertical:
|
||||
expected_output = np.flip(expected_output, axis=2)
|
||||
with test.mock.patch.object(
|
||||
random_ops, 'random_uniform', return_value=mock_random):
|
||||
with tf_test_util.use_gpu():
|
||||
layer = image_preprocessing.RandomFlip(flip_horizontal, flip_vertical)
|
||||
actual_output = layer(inp, training=1)
|
||||
self.assertAllClose(expected_output, actual_output)
|
||||
|
||||
@parameterized.named_parameters(('random_flip_horizontal', True, False),
|
||||
('random_flip_vertical', False, True),
|
||||
('random_flip_both', True, True),
|
||||
('random_flip_neither', False, False))
|
||||
def test_random_flip(self, flip_horizontal, flip_vertical):
|
||||
with CustomObjectScope({'RandomFlip': image_preprocessing.RandomFlip}):
|
||||
self._run_test(flip_horizontal, flip_vertical)
|
||||
|
||||
def test_random_flip_horizontal_half(self):
|
||||
with CustomObjectScope({'RandomFlip': image_preprocessing.RandomFlip}):
|
||||
np.random.seed(1337)
|
||||
mock_random = [1, 0]
|
||||
mock_random = np.reshape(mock_random, [2, 1, 1, 1])
|
||||
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
|
||||
expected_output = input_images.copy()
|
||||
expected_output[0, :, :, :] = np.flip(input_images[0, :, :, :], axis=0)
|
||||
self._run_test(True, False, expected_output, mock_random)
|
||||
|
||||
def test_random_flip_vertical_half(self):
|
||||
with CustomObjectScope({'RandomFlip': image_preprocessing.RandomFlip}):
|
||||
np.random.seed(1337)
|
||||
mock_random = [1, 0]
|
||||
mock_random = np.reshape(mock_random, [2, 1, 1, 1])
|
||||
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
|
||||
expected_output = input_images.copy()
|
||||
expected_output[0, :, :, :] = np.flip(input_images[0, :, :, :], axis=1)
|
||||
self._run_test(False, True, expected_output, mock_random)
|
||||
|
||||
def test_random_flip_inference(self):
|
||||
with CustomObjectScope({'RandomFlip': image_preprocessing.RandomFlip}):
|
||||
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
|
||||
expected_output = input_images
|
||||
with tf_test_util.use_gpu():
|
||||
layer = image_preprocessing.RandomFlip(True, True)
|
||||
actual_output = layer(input_images, training=0)
|
||||
self.assertAllClose(expected_output, actual_output)
|
||||
|
||||
def test_random_flip_default(self):
|
||||
with CustomObjectScope({'RandomFlip': image_preprocessing.RandomFlip}):
|
||||
input_images = np.random.random((2, 5, 8, 3)).astype(np.float32)
|
||||
expected_output = np.flip(np.flip(input_images, axis=1), axis=2)
|
||||
mock_random = [1, 1]
|
||||
mock_random = np.reshape(mock_random, [2, 1, 1, 1])
|
||||
with test.mock.patch.object(
|
||||
random_ops, 'random_uniform', return_value=mock_random):
|
||||
with self.cached_session(use_gpu=True):
|
||||
layer = image_preprocessing.RandomFlip()
|
||||
actual_output = layer(input_images, training=1)
|
||||
self.assertAllClose(expected_output, actual_output)
|
||||
|
||||
@tf_test_util.run_v2_only
|
||||
def test_config_with_custom_name(self):
|
||||
layer = image_preprocessing.RandomFlip(5, 5, name='image_preproc')
|
||||
config = layer.get_config()
|
||||
layer_1 = image_preprocessing.RandomFlip.from_config(config)
|
||||
self.assertEqual(layer_1.name, layer.name)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user