From fc296acdc1d454596d9e0e531656858f3b0acca6 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 17 Jun 2020 09:00:01 -0700 Subject: [PATCH] Make axis handling for Normalization more robust. PiperOrigin-RevId: 316898233 Change-Id: I6888216ed21c4d2a482772fb2a314160750185b6 --- .../layers/preprocessing/normalization.py | 67 ++++++++++++------- .../preprocessing/normalization_test.py | 43 ++++++++++++ 2 files changed, 86 insertions(+), 24 deletions(-) diff --git a/tensorflow/python/keras/layers/preprocessing/normalization.py b/tensorflow/python/keras/layers/preprocessing/normalization.py index 09564cbb064..ba2f7eaae89 100644 --- a/tensorflow/python/keras/layers/preprocessing/normalization.py +++ b/tensorflow/python/keras/layers/preprocessing/normalization.py @@ -53,10 +53,13 @@ class Normalization(CombinerPreprocessingLayer): Attributes: axis: Integer or tuple of integers, the axis or axes that should be - normalized (typically the features axis). We will normalize each element - in the specified axis. If set to 'None', the layer will perform scalar - normalization (diving the input by a single scalar value). 0 (the batch - axis) is not allowed. + "kept". These axes are not summed over when calculating the + normalization statistics. By default the last axis, the `features` axis + is kept and any `space` or `time` axes are summed. Each element in + the axes that are kept is normalized independently. If `axis` is set to + 'None', the layer will perform scalar normalization (dividing the input + by a single scalar value). The `batch` axis, 0, is always summed over + (`axis=0` is not allowed). Examples: @@ -78,10 +81,18 @@ class Normalization(CombinerPreprocessingLayer): # time, the dtype value will change to reflect it. dtype = dtype or K.floatx() + # Standardize `axis` to a tuple. 
+ if axis is None: + axis = () + elif isinstance(axis, int): + axis = (axis,) + else: + axis = tuple(axis) + super(Normalization, self).__init__( combiner=_NormalizingCombiner(axis), dtype=dtype, **kwargs) - if axis == 0: + if 0 in axis: raise ValueError('The argument \'axis\' may not be 0.') self.axis = axis @@ -90,18 +101,27 @@ class Normalization(CombinerPreprocessingLayer): input_shape = tensor_shape.TensorShape(input_shape).as_list() if len(input_shape) == 1: input_shape = input_shape + [1] + + ndim = len(input_shape) + + # Sort `self.axis` to avoid transposing `mean_and_var_shape`. + # Negative axes are not sortable until you know the number of dimensions. + original_axis = self.axis + self.axis = tuple(sorted(self.axis, + key=lambda a: a if a >= 0 else ndim + a)) + + if any(a < 1-ndim for a in self.axis) or any(a >= ndim for a in self.axis): + raise ValueError('All `axis` values must be in ' + 'the range [1-ndim, ndim-1].\n' + 'Got:\n' + ' ndim: {}\n' + ' axis: {}'.format(ndim, original_axis)) + self._broadcast_shape = [1 for _ in range(len(input_shape))] - if isinstance(self.axis, (tuple, list)): - mean_and_var_shape = [] - for i in self.axis: - mean_and_var_shape.append(input_shape[i]) - self._broadcast_shape[i] = input_shape[i] - else: - if self.axis is None: - mean_and_var_shape = () - else: - mean_and_var_shape = input_shape[self.axis] - self._broadcast_shape[self.axis] = input_shape[self.axis] + mean_and_var_shape = [] + for i in self.axis: + mean_and_var_shape.append(input_shape[i]) + self._broadcast_shape[i] = input_shape[i] # count is not used in this class's call() method, but is used to re-create # the accumulator during multiple calls to 'adapt'. @@ -179,11 +199,13 @@ class _NormalizingCombiner(Combiner): if values.ndim == 1: values = np.expand_dims(values, 1) + # `np.delete` ignores negative indexes, so use a mask to delete items. 
+ axis_mask = np.ones([values.ndim], dtype=bool) + axis_mask[np.array(self.axis, dtype=np.int32)] = False + # This is the shape of all reduced axes (not specified in 'axis'). - if self.axis is None: - reduction_counts = values.shape - else: - reduction_counts = np.delete(values.shape, self.axis) + + reduction_counts = np.array(values.shape)[axis_mask] # We get the number of elements that will be reduced by multiplying all # values of 'shape' corresponding to the reduced axes. count = np.prod(reduction_counts, dtype=np.int64) @@ -191,10 +213,7 @@ class _NormalizingCombiner(Combiner): # We want to reduce across dimensions except those specified in 'axis' # when using np.mean or np.variance; create the tuple of axes to reduce # over here. - if self.axis is None: - reduction_axes = None - else: - reduction_axes = tuple(np.delete(range(values.ndim), self.axis)) + reduction_axes = tuple(np.arange(values.ndim)[axis_mask]) mean = np.mean(values, axis=reduction_axes, dtype=np.float64) variance = np.var(values, axis=reduction_axes, dtype=np.float64) diff --git a/tensorflow/python/keras/layers/preprocessing/normalization_test.py b/tensorflow/python/keras/layers/preprocessing/normalization_test.py index 75ef9370899..f5f68d9c51a 100644 --- a/tensorflow/python/keras/layers/preprocessing/normalization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/normalization_test.py @@ -275,6 +275,49 @@ class NormalizationTest(keras_parameterized.TestCase, if context.executing_eagerly(): self.assertAllClose(output.numpy(), [[-1], [1], [-1], [1]]) + @parameterized.parameters( + {"axis": 0}, + {"axis": (-1, 0)}, + ) + def test_zeros_fail_init(self, axis): + cls = get_layer_class() + with self.assertRaisesRegex(ValueError, + "The argument 'axis' may not be 0."): + cls(axis=axis) + + @parameterized.parameters( + # Out of bounds + {"axis": 3}, + {"axis": -3}, + # In a tuple + {"axis": (1, 3)}, + {"axis": (1, -3)}, + ) + def test_bad_axis_fail_build(self, axis): + cls = get_layer_class() 
+ layer = cls(axis=axis) + with self.assertRaisesRegex(ValueError, + r"in the range \[1-ndim, ndim-1\]."): + layer.build([None, 2, 3]) + + @parameterized.parameters( + # Results should be identical no matter how the axes are specified (3d). + {"axis": (1, 2)}, + {"axis": (2, 1)}, + {"axis": (1, -1)}, + {"axis": (-1, 1)}, + ) + def test_axis_permutations(self, axis): + cls = get_layer_class() + layer = cls(axis=axis) + # data.shape = [2, 2, 3] + data = np.array([[[0., 1., 2.], [0., 2., 6.]], + [[2., 3., 4.], [3., 6., 10.]]]) + expect = np.array([[[-1., -1., -1.], [-1., -1., -1.]], + [[1., 1., 1.], [1., 1., 1.]]]) + layer.adapt(data) + self.assertAllClose(expect, layer(data)) + if __name__ == "__main__": test.main()