KPL - Implement Normalization with new APIs.

PiperOrigin-RevId: 356351476
Change-Id: I66f061a36cf096239fe2b0f2aaa380dba55a7569

commit 236369d651 (parent 2cc0ab3c0c)
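A minimal usage sketch of the layer after this change, based on the new constructor signature and the `adapt`/`call` logic in the diff below; the `tf.keras` import path and the sample data are assumptions for illustration, not part of the commit:

import numpy as np
import tensorflow as tf

data = np.array([[1.], [2.], [3.]], dtype='float32')

# Adapt-based: mean/variance are computed from the data and stored as weights.
norm = tf.keras.layers.experimental.preprocessing.Normalization(axis=-1)
norm.adapt(data)
print(norm(data))  # approximately zero mean and unit variance per feature

# Direct initialization: `mean` and `variance` must be passed together.
norm_direct = tf.keras.layers.experimental.preprocessing.Normalization(
    axis=-1, mean=[2.], variance=[2. / 3.])
print(norm_direct(data))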
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

import numpy as np

from tensorflow.python.framework import dtypes
@@ -30,26 +28,13 @@ from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export

_COUNT_NAME = 'count'
_MEAN_NAME = 'mean'
_VARIANCE_NAME = 'variance'


def convert_to_ndarray(values):
  if isinstance(values, np.ndarray):
    return values
  elif isinstance(values, ops.Tensor):
    return K.get_value(values)
  else:
    return np.array(values)


@keras_export('keras.layers.experimental.preprocessing.Normalization', v1=[])
class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
class Normalization(base_preprocessing_layer.PreprocessingLayer):
  """Feature-wise normalization of the data.

  This layer will coerce its inputs into a distribution centered around
@@ -104,6 +89,9 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
  """

  def __init__(self, axis=-1, mean=None, variance=None, **kwargs):
    super(Normalization, self).__init__(stateful=True, streaming=True, **kwargs)
    base_preprocessing_layer.keras_kpl_gauge.get_cell('Normalization').set(True)

    # Standardize `axis` to a tuple.
    if axis is None:
      axis = ()
@@ -111,23 +99,17 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
      axis = (axis,)
    else:
      axis = tuple(axis)

    super(Normalization, self).__init__(
        combiner=_NormalizingCombiner(axis), **kwargs)
    base_preprocessing_layer.keras_kpl_gauge.get_cell('Normalization').set(True)

    if 0 in axis:
      raise ValueError('The argument \'axis\' may not be 0.')

    self.axis = axis

    # Set `mean` and `variance` if passed.
    if isinstance(mean, variables.Variable):
      raise ValueError('Normalization does not support passing a Variable '
                       'for the `mean` init arg.')
    if isinstance(variance, variables.Variable):
      raise ValueError('Normalization does not support passing a Variable '
                       'for the `variance` init arg.')

    if mean is not None and variance is not None:
      mean = convert_to_ndarray(mean)
      variance = convert_to_ndarray(variance)
@@ -135,7 +117,6 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
      raise ValueError(
          'When setting values directly, both `mean` and `variance` '
          'must be set. Got mean: {} and variance: {}'.format(mean, variance))

    self.mean_val = mean
    self.variance_val = variance

@@ -143,62 +124,121 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    if len(input_shape) == 1:
      input_shape = input_shape + [1]

    ndim = len(input_shape)

    # Sort `self.axis` to avoid transposing `mean_and_var_shape`.
    # Negative axes are not sortable until you know the number of dimensions.
    original_axis = self.axis
    self.axis = tuple(sorted(self.axis,
                             key=lambda a: a if a >= 0 else ndim + a))
    if any(a < 1 - ndim or a >= ndim for a in self.axis):
      raise ValueError('All `axis` values must be in the range '
                       '[1 - ndim, ndim - 1]. Found '
                       'ndim: `{}`, axis: {}'.format(ndim, self.axis))

    if any(a < 1-ndim for a in self.axis) or any(a >= ndim for a in self.axis):
      raise ValueError('All `axis` values must be in '
                       'the range [1-ndim, ndim-1].\n'
                       'Got:\n'
                       ' ndim: {}\n'
                       ' axis: {}'.format(ndim, original_axis))
    # Axes to be kept, replacing negative values with positive equivalents.
    # Sorted to avoid transposing axes.
    self._keep_axis = sorted([d if d >= 0 else d + ndim for d in self.axis])
    # Axes to be reduced.
    self._reduce_axis = [d for d in range(ndim) if d not in self._keep_axis]
    # 1 if an axis should be reduced, 0 otherwise.
    self._reduce_axis_mask = [
        0 if d in self._keep_axis else 1 for d in range(ndim)
    ]
    # Broadcast any reduced axes.
    self._broadcast_shape = [
        input_shape[d] if d in self._keep_axis else 1 for d in range(ndim)
    ]
    # Create variables without keeping reduced axes.
    mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)
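    # Illustrative example (not part of the original change): for a built
    # input_shape of [None, 3, 4] and axis=-1, _keep_axis is [2],
    # _reduce_axis is [0, 1], _broadcast_shape is [1, 1, 4], and
    # mean_and_var_shape is (4,).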

    self._broadcast_shape = [1 for _ in range(len(input_shape))]
    mean_and_var_shape = []
    for i in self.axis:
      mean_and_var_shape.append(input_shape[i])
      self._broadcast_shape[i] = input_shape[i]

    # count is not used in this class's call() method, but is used to re-create
    # the accumulator during multiple calls to 'adapt'.
    # TODO(omalleyt): should mean and variance be set to self.dtype?
    self.mean = self._add_state_variable(
        name=_MEAN_NAME,
    self.mean = self.add_weight(
        name='mean',
        shape=mean_and_var_shape,
        dtype=K.floatx(),
        initializer=init_ops.zeros_initializer)
    self.variance = self._add_state_variable(
        name=_VARIANCE_NAME,
        dtype=self.dtype,
        initializer=init_ops.zeros_initializer,
        trainable=False)
    self.variance = self.add_weight(
        name='variance',
        shape=mean_and_var_shape,
        dtype=K.floatx(),
        initializer=init_ops.ones_initializer)
    self.count = self._add_state_variable(
        name=_COUNT_NAME,
        dtype=self.dtype,
        initializer=init_ops.ones_initializer,
        trainable=False)
    self.count = self.add_weight(
        name='count',
        shape=(),
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer)
        initializer=init_ops.zeros_initializer,
        trainable=False)

    super(Normalization, self).build(input_shape)

    if (self.mean_val is not None and self.variance_val is not None):
      mean_val = self.mean_val * np.ones(mean_and_var_shape)
      variance_val = self.variance_val * np.ones(mean_and_var_shape)
      self.set_weights([mean_val, variance_val])
      self.mean.assign(mean_val)
      self.variance.assign(variance_val)

    self.built = True

  def update_state(self, data):
    if not self.built:
      raise RuntimeError('`build` must be called before `update_state`.')

    data = self._standardize_inputs(data)
    batch_mean, batch_variance = nn_impl.moments_v2(
        data, axes=self._reduce_axis)
    batch_shape = array_ops.shape(data, out_type=self.count.dtype)
    batch_reduce_shape = array_ops.gather(batch_shape, self._reduce_axis)
    batch_count = math_ops.reduce_prod(batch_reduce_shape)

    total_count = batch_count + self.count
    batch_weight = (
        math_ops.cast(batch_count, dtype=self.dtype) /
        math_ops.cast(total_count, dtype=self.dtype))
    existing_weight = 1. - batch_weight

    total_mean = self.mean * existing_weight + batch_mean * batch_weight
    # The variance is computed using the lack-of-fit sum of squares
    # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
    total_variance = ((self.variance +
                       (self.mean - total_mean)**2) * existing_weight +
                      (batch_variance +
                       (batch_mean - total_mean)**2) * batch_weight)
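    # Equivalently: with running statistics (mean_0, var_0, n_0) and batch
    # statistics (mean_b, var_b, n_b), the pooled statistics are
    #   mean = mean_0 * n_0/(n_0+n_b) + mean_b * n_b/(n_0+n_b)
    #   var  = (var_0 + (mean_0 - mean)**2) * n_0/(n_0+n_b)
    #        + (var_b + (mean_b - mean)**2) * n_b/(n_0+n_b)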
    self.mean.assign(total_mean)
    self.variance.assign(total_variance)
    self.count.assign(total_count)

  def merge_state(self, layers):
    layers = layers + [self]
    if any(not l.built for l in layers):
      raise ValueError(
          'All layers to be merged must have been adapted to some inputs '
          'first (otherwise they have no state).')

    layer_counts = [l.count for l in layers]
    layer_means = [l.mean for l in layers]
    layer_variances = [l.variance for l in layers]

    total_count = math_ops.reduce_sum(layer_counts)
    layer_weightings = (
        math_ops.cast(layer_counts, self.dtype) /
        math_ops.cast(total_count, self.dtype))
    layer_weightings = array_ops.reshape(
        layer_weightings, shape=[len(layers)] + [1] * self.mean.shape.rank)

    total_mean = math_ops.reduce_sum(layer_means * layer_weightings, axis=0)
    inter_layer_variances = (layer_means - total_mean)**2
    total_variance = math_ops.reduce_sum(
        ((layer_variances + inter_layer_variances) * layer_weightings), axis=0)

    self.mean.assign(total_mean)
    self.variance.assign(total_variance)
    self.count.assign(total_count)

  def reset_state(self):  # pylint: disable=method-hidden
    if self.built:
      self.mean.assign(array_ops.zeros_like(self.mean))
      self.variance.assign(array_ops.ones_like(self.variance))
      self.count.assign(array_ops.zeros_like(self.count))

  def call(self, inputs):
    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
    if inputs.shape.rank == 1:
      inputs = array_ops.expand_dims(inputs, 1)
    # If the inputs are not floats, cast them to floats. This avoids issues
    # with int-float multiplication and division below.
    if inputs.dtype != K.floatx():
      inputs = math_ops.cast(inputs, K.floatx())
    inputs = self._standardize_inputs(inputs)
    # We need to reshape the mean and variance data to ensure that Tensorflow
    # broadcasts the data correctly.
    mean = array_ops.reshape(self.mean, self._broadcast_shape)
@@ -213,9 +253,9 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
    return input_spec

  def get_config(self):
    config = {'axis': self.axis}
    base_config = super(Normalization, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
    config = super(Normalization, self).get_config()
    config.update({'axis': self.axis})
    return config

  def set_weights(self, weights):
    """Override for set_weights to ensure we can set just mean/var weights."""
@@ -223,149 +263,22 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
      weights.append(np.array(0))
    super(Normalization, self).set_weights(weights)

  def _standardize_inputs(self, inputs):
    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
    if inputs.shape.rank == 0:
      inputs = array_ops.reshape(inputs, [1, 1])
    elif inputs.shape.rank == 1:
      inputs = array_ops.expand_dims(inputs, 1)

class _NormalizingCombiner(base_preprocessing_layer.Combiner):
  """Combiner for the Normalization preprocessing layer.
    if inputs.dtype != self.dtype:
      inputs = math_ops.cast(inputs, self.dtype)
    return inputs

  This class encapsulates the computations for finding the mean and variance
  of a set of data in a stable and numerically correct way. Its associated
  accumulator is a namedtuple('count', 'mean', 'variance').

  Attributes:
    axis: The axis to compute mean and var over.
  """
  COUNT_IDX = 0
  MEAN_IDX = 1
  VAR_IDX = 2

  def __init__(self, axis):
    self.axis = axis

  def compute(self, values, accumulator=None):
    """Compute a step in this computation, returning a new accumulator."""
    values = np.array(values)
    if values.ndim == 1:
      values = np.expand_dims(values, 1)

    # `np.delete` ignores negative indexes, so use a mask to delete items.
    axis_mask = np.ones([values.ndim], dtype=bool)
    axis_mask[np.array(self.axis, dtype=np.int32)] = False

    # This is the shape of all reduced axes (not specified in 'axis').

    reduction_counts = np.array(values.shape)[axis_mask]
    # We get the number of elements that will be reduced by multiplying all
    # values of 'shape' corresponding to the reduced axes.
    count = np.prod(reduction_counts, dtype=np.int64)

    # We want to reduce across dimensions except those specified in 'axis'
    # when using np.mean or np.variance; create the tuple of axes to reduce
    # over here.
    reduction_axes = tuple(np.arange(values.ndim)[axis_mask])

    mean = np.mean(values, axis=reduction_axes, dtype=np.float64)
    variance = np.var(values, axis=reduction_axes, dtype=np.float64)

    # Create an accumulator with our new data and either return it or combine
    # it with the passed accumulator.
    if accumulator is None:
      return self._create_accumulator(count, mean, variance)
def convert_to_ndarray(values):
  if isinstance(values, np.ndarray):
    return values
  elif isinstance(values, ops.Tensor):
    return K.get_value(values)
  else:
      return self.add_data_to_accumulator(count, mean, variance, accumulator)

  def add_data_to_accumulator(self, count, mean, variance, accumulator):
    """Add new data to the totals in an accumulator."""
    # Combine accumulators and return the result.
    combined_count = count + accumulator[self.COUNT_IDX]

    # To combine accumulator means, we weight each accumulator's mean by the
    # number of elements that were accumulated, and then divide by the
    # total number of elements.
    combined_mean = (mean * count + accumulator[self.MEAN_IDX] *
                     accumulator[self.COUNT_IDX]) / combined_count

    # The variance is computed using the lack-of-fit sum of squares
    # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
    accumulator_var_contribution = accumulator[self.COUNT_IDX] * (
        accumulator[self.VAR_IDX] +
        np.square(accumulator[self.MEAN_IDX] - combined_mean))
    data_var_contribution = count * (variance + np.square(mean - combined_mean))
    combined_variance = (accumulator_var_contribution +
                         data_var_contribution) / combined_count

    accumulator[self.COUNT_IDX] = combined_count
    accumulator[self.MEAN_IDX] = np.nan_to_num(combined_mean)
    accumulator[self.VAR_IDX] = np.nan_to_num(combined_variance)
    return accumulator

  def merge(self, accumulators):
    """Merge several accumulators to a single accumulator."""
    # Combine accumulators and return the result.
    combined_count = np.sum(
        [accumulator[self.COUNT_IDX] for accumulator in accumulators])

    # To combine accumulator means, we weight each accumulator's mean by the
    # number of elements that were accumulated, and then divide by the
    # total number of elements.
    combined_mean = np.add.reduce([
        accumulator[self.MEAN_IDX] * accumulator[self.COUNT_IDX]
        for accumulator in accumulators
    ]) / combined_count

    # The variance is computed using the lack-of-fit sum of squares
    # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
    def variance_contribution(accumulator):
      return accumulator[self.COUNT_IDX] * (
          accumulator[self.VAR_IDX] +
          np.square(accumulator[self.MEAN_IDX] - combined_mean))

    combined_variance = np.add.reduce([
        variance_contribution(accumulator) for accumulator in accumulators
    ]) / combined_count

    return self._create_accumulator(combined_count, combined_mean,
                                    combined_variance)

  def extract(self, accumulator):
    """Convert an accumulator into a dict of output values."""
    return {
        _COUNT_NAME: accumulator[self.COUNT_IDX],
        _MEAN_NAME: accumulator[self.MEAN_IDX],
        _VARIANCE_NAME: accumulator[self.VAR_IDX]
    }

  def restore(self, output):
    """Create an accumulator based on 'output'."""
    # There is no special internal state here, so we just return the relevant
    # internal value.
    count = output[_COUNT_NAME]
    mean = output[_MEAN_NAME]
    var = output[_VARIANCE_NAME]
    if (count == 0 and (mean.any() != 0.0 or var.any() != 0.0)):
      raise RuntimeError(
          'The mean and/or variance of a Normalization preprocessing layer '
          "were set without also setting 'count'. If 'count' is not also set, "
          " or was set to 0, 'adapt' cannot be called unless the 'reset_state'"
          'arg is True.')
    return self._create_accumulator(output[_COUNT_NAME], output[_MEAN_NAME],
                                    output[_VARIANCE_NAME])

  def serialize(self, accumulator):
    """Serialize an accumulator for a remote call."""
    output_dict = {
        _COUNT_NAME: accumulator[self.COUNT_IDX].tolist(),
        _MEAN_NAME: accumulator[self.MEAN_IDX].tolist(),
        _VARIANCE_NAME: accumulator[self.VAR_IDX].tolist()
    }
    return compat.as_bytes(json.dumps(output_dict))

  def deserialize(self, encoded_accumulator):
    """Deserialize an accumulator received from 'serialize()'."""
    value_dict = json.loads(compat.as_text(encoded_accumulator))
    return self._create_accumulator(
        np.array(value_dict[_COUNT_NAME]), np.array(value_dict[_MEAN_NAME]),
        np.array(value_dict[_VARIANCE_NAME]))

  def _create_accumulator(self, count, mean, variance):
    """Convert any 'nan' values in the given accumulator to numeric values."""
    return [count, mean, variance]
    return np.array(values)

@@ -141,80 +141,6 @@ class NormalizationTest(keras_parameterized.TestCase,
    expected = np.array([[3., -3., -0.33333333], [9., 5., 1.]])
    self.assertAllClose(expected, output_data)

  def test_combiner_api_compatibility(self):
    data = np.array([[1], [2], [3], [4], [5]])
    combiner = normalization._NormalizingCombiner(axis=-1)
    expected = {
        "count": np.array(5.0),
        "variance": np.array([2.]),
        "mean": np.array([3.])
    }
    expected_accumulator = combiner._create_accumulator(expected["count"],
                                                        expected["mean"],
                                                        expected["variance"])
    self.validate_accumulator_serialize_and_deserialize(combiner, data,
                                                        expected_accumulator)
    self.validate_accumulator_uniqueness(combiner, data)
    self.validate_accumulator_extract(combiner, data, expected)
    self.validate_accumulator_extract_and_restore(combiner, data,
                                                  expected)

  @parameterized.named_parameters(
      {
          "data": np.array([[1], [2], [3], [4], [5]]),
          "axis": -1,
          "expected": {
              "count": np.array(5.0),
              "variance": np.array([2.]),
              "mean": np.array([3.])
          },
          "testcase_name": "2d_single_element"
      }, {
          "data": np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]),
          "axis": -1,
          "expected": {
              "count": np.array(5.0),
              "mean": np.array([3., 4.]),
              "variance": np.array([2., 2.])
          },
          "testcase_name": "2d_multi_element"
      }, {
          "data": np.array([[[1, 2]], [[2, 3]], [[3, 4]], [[4, 5]], [[5, 6]]]),
          "axis": 2,
          "expected": {
              "count": np.array(5.0),
              "mean": np.array([3., 4.]),
              "variance": np.array([2., 2.])
          },
          "testcase_name": "3d_multi_element"
      }, {
          "data": np.array([[[1, 2]], [[2, 3]], [[3, 4]], [[4, 5]], [[5, 6]]]),
          "axis": (1, 2),
          "expected": {
              "count": np.array(5.0),
              "mean": np.array([[3., 4.]]),
              "variance": np.array([[2., 2.]])
          },
          "testcase_name": "3d_multi_element_multi_axis"
      }, {
          "data":
              np.array([[[1, 2], [2, 3]], [[3, 4], [4, 5]], [[1, 2], [2, 3]],
                        [[3, 4], [4, 5]]]),
          "axis":
              1,
          "expected": {
              "count": np.array(8.0),
              "mean": np.array([2.5, 3.5]),
              "variance": np.array([1.25, 1.25])
          },
          "testcase_name":
              "3d_multi_element_internal_axis"
      })
  def test_combiner_computation_multi_value_axis(self, data, axis, expected):
    combiner = normalization._NormalizingCombiner(axis=axis)
    expected_accumulator = combiner._create_accumulator(**expected)
    self.validate_accumulator_computation(combiner, data, expected_accumulator)

  @parameterized.named_parameters(*_get_layer_computation_test_cases())
  def test_layer_computation(self, adapt_data, axis, test_data, use_dataset,
                             expected):
@@ -286,6 +212,18 @@ class NormalizationTest(keras_parameterized.TestCase,
    if context.executing_eagerly():
      self.assertAllClose(output.numpy(), [[-1], [1], [-1], [1]])

  def test_0d_data(self):
    if not context.executing_eagerly():
      self.skipTest("Only supported in TF2.")

    data = [0, 2, 0, 2]
    cls = get_layer_class()
    layer = cls(axis=-1)
    layer.adapt(data)
    output = layer(0.)
    self.assertListEqual(output.shape.as_list(), [1, 1])
    self.assertAllClose(output.numpy(), [[-1]])

  @parameterized.parameters(
      {"axis": 0},
      {"axis": (-1, 0)},
@@ -307,8 +245,7 @@ class NormalizationTest(keras_parameterized.TestCase,
  def test_bad_axis_fail_build(self, axis):
    cls = get_layer_class()
    layer = cls(axis=axis)
    with self.assertRaisesRegex(ValueError,
                                r"in the range \[1-ndim, ndim-1\]."):
    with self.assertRaisesRegex(ValueError, r"in the range"):
      layer.build([None, 2, 3])

  @parameterized.parameters(
@@ -341,6 +278,33 @@ class NormalizationTest(keras_parameterized.TestCase,
        keras.layers.Dense(1)])
    model.summary()

  def test_merge_state(self):
    if not context.executing_eagerly():
      self.skipTest("`merge_state` only supported in TF2")

    cls = get_layer_class()
    data = np.random.rand(30, 10, 2)
    ds = dataset_ops.Dataset.from_tensor_slices(data).batch(2)
    norm = cls(axis=(1, 2))
    norm.adapt(ds)

    partial_ds_1 = ds.shard(3, 0)
    partial_ds_2 = ds.shard(3, 1)
    partial_ds_3 = ds.shard(3, 2)

    norm_1 = cls(axis=(1, 2))
    norm_2 = cls(axis=(1, 2))
    norm_3 = cls(axis=(1, 2))

    norm_1.adapt(partial_ds_1)
    norm_2.adapt(partial_ds_2)
    norm_3.adapt(partial_ds_3)

    norm_1.merge_state([norm_2, norm_3])
    merged_norm = norm_1

    self.assertAllClose(norm(data), merged_norm(data))


if __name__ == "__main__":
  test.main()

@@ -18,16 +18,326 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.engine.base_preprocessing_layer_v1 import CombinerPreprocessingLayer
from tensorflow.python.keras.layers.preprocessing import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export

_COUNT_NAME = 'count'
_MEAN_NAME = 'mean'
_VARIANCE_NAME = 'variance'


def convert_to_ndarray(values):
  if isinstance(values, np.ndarray):
    return values
  elif isinstance(values, ops.Tensor):
    return K.get_value(values)
  else:
    return np.array(values)


@keras_export(v1=['keras.layers.experimental.preprocessing.Normalization'])
class Normalization(normalization.Normalization, CombinerPreprocessingLayer):
class Normalization(CombinerPreprocessingLayer):
  """Feature-wise normalization of the data.

  This layer will coerce its inputs into a distribution centered around
  0 with standard deviation 1. It accomplishes this by precomputing the mean and
  variance of the data, and calling (input-mean)/sqrt(var) at runtime.

  What happens in `adapt`: Compute mean and variance of the data and store them
  as the layer's weights. `adapt` should be called before `fit`, `evaluate`,
  or `predict`.

  Attributes:
    axis: Integer or tuple of integers, the axis or axes that should be
      "kept". These axes are not summed over when calculating the
      normalization statistics. By default the last axis, the `features`
      axis, is kept and any `space` or `time` axes are summed. Each element
      in the axes that are kept is normalized independently. If `axis` is set
      to 'None', the layer will perform scalar normalization (dividing the
      input by a single scalar value). The `batch` axis, 0, is always summed
      over (`axis=0` is not allowed).
"""
|
||||
|
||||
def __init__(self, axis=-1, **kwargs):
|
||||
super(Normalization, self).__init__(axis, **kwargs)
|
||||
base_preprocessing_layer.keras_kpl_gauge.get_cell(
|
||||
'Normalization v1').set(True)
|
||||
# Standardize `axis` to a tuple.
|
||||
if axis is None:
|
||||
axis = ()
|
||||
elif isinstance(axis, int):
|
||||
axis = (axis,)
|
||||
else:
|
||||
axis = tuple(axis)
|
||||
|
||||
mean = kwargs.pop('mean', None)
|
||||
variance = kwargs.pop('variance', None)
|
||||
|
||||
super(Normalization, self).__init__(
|
||||
combiner=_NormalizingCombiner(axis), **kwargs)
|
||||
base_preprocessing_layer.keras_kpl_gauge.get_cell('Normalization').set(True)
|
||||
|
||||
if 0 in axis:
|
||||
raise ValueError('The argument \'axis\' may not be 0.')
|
||||
|
||||
self.axis = axis
|
||||
|
||||
if isinstance(mean, variables.Variable):
|
||||
raise ValueError('Normalization does not support passing a Variable '
|
||||
'for the `mean` init arg.')
|
||||
if isinstance(variance, variables.Variable):
|
||||
raise ValueError('Normalization does not support passing a Variable '
|
||||
'for the `variance` init arg.')
|
||||
if mean is not None and variance is not None:
|
||||
mean = convert_to_ndarray(mean)
|
||||
variance = convert_to_ndarray(variance)
|
||||
elif mean is not None or variance is not None:
|
||||
raise ValueError(
|
||||
'When setting values directly, both `mean` and `variance` '
|
||||
'must be set. Got mean: {} and variance: {}'.format(mean, variance))
|
||||
|
||||
self.mean_val = mean
|
||||
self.variance_val = variance
|
||||
|
||||
def build(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape).as_list()
|
||||
if len(input_shape) == 1:
|
||||
input_shape = input_shape + [1]
|
||||
|
||||
ndim = len(input_shape)
|
||||
|
||||
# Sort `self.axis` to avoid transposing `mean_and_var_shape`.
|
||||
# Negative axes are not sortable until you know the number of dimensions.
|
||||
original_axis = self.axis
|
||||
self.axis = tuple(
|
||||
sorted(self.axis, key=lambda a: a if a >= 0 else ndim + a))
|
||||
|
||||
if any(a < 1 - ndim for a in self.axis) or any(
|
||||
a >= ndim for a in self.axis):
|
||||
raise ValueError('All `axis` values must be in '
|
||||
'the range [1-ndim, ndim-1].\n'
|
||||
'Got:\n'
|
||||
' ndim: {}\n'
|
||||
' axis: {}'.format(ndim, original_axis))
|
||||
|
||||
self._broadcast_shape = [1 for _ in range(len(input_shape))]
|
||||
mean_and_var_shape = []
|
||||
for i in self.axis:
|
||||
mean_and_var_shape.append(input_shape[i])
|
||||
self._broadcast_shape[i] = input_shape[i]
|
||||
|
||||
# count is not used in this class's call() method, but is used to re-create
|
||||
# the accumulator during multiple calls to 'adapt'.
|
||||
# TODO(omalleyt): should mean and variance be set to self.dtype?
|
||||
self.mean = self._add_state_variable(
|
||||
name=_MEAN_NAME,
|
||||
shape=mean_and_var_shape,
|
||||
dtype=K.floatx(),
|
||||
initializer=init_ops.zeros_initializer)
|
||||
self.variance = self._add_state_variable(
|
||||
name=_VARIANCE_NAME,
|
||||
shape=mean_and_var_shape,
|
||||
dtype=K.floatx(),
|
||||
initializer=init_ops.ones_initializer)
|
||||
self.count = self._add_state_variable(
|
||||
name=_COUNT_NAME,
|
||||
shape=(),
|
||||
dtype=dtypes.int64,
|
||||
initializer=init_ops.zeros_initializer)
|
||||
|
||||
super(Normalization, self).build(input_shape)
|
||||
|
||||
if (self.mean_val is not None and self.variance_val is not None):
|
||||
mean_val = self.mean_val * np.ones(mean_and_var_shape)
|
||||
variance_val = self.variance_val * np.ones(mean_and_var_shape)
|
||||
self.set_weights([mean_val, variance_val])
|
||||
|
||||
def call(self, inputs):
|
||||
inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
|
||||
if inputs.shape.rank == 1:
|
||||
inputs = array_ops.expand_dims(inputs, 1)
|
||||
# If the inputs are not floats, cast them to floats. This avoids issues
|
||||
# with int-float multiplication and division below.
|
||||
if inputs.dtype != K.floatx():
|
||||
inputs = math_ops.cast(inputs, K.floatx())
|
||||
# We need to reshape the mean and variance data to ensure that Tensorflow
|
||||
# broadcasts the data correctly.
|
||||
mean = array_ops.reshape(self.mean, self._broadcast_shape)
|
||||
variance = array_ops.reshape(self.variance, self._broadcast_shape)
|
||||
return ((inputs - mean) /
|
||||
math_ops.maximum(math_ops.sqrt(variance), K.epsilon()))
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
|
||||
def compute_output_signature(self, input_spec):
|
||||
return input_spec
|
||||
|
||||
def get_config(self):
|
||||
config = {'axis': self.axis}
|
||||
base_config = super(Normalization, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def set_weights(self, weights):
|
||||
"""Override for set_weights to ensure we can set just mean/var weights."""
|
||||
if len(weights) == 2:
|
||||
weights.append(np.array(0))
|
||||
super(Normalization, self).set_weights(weights)
|
||||
|
||||
|
||||
class _NormalizingCombiner(base_preprocessing_layer.Combiner):
|
||||
"""Combiner for the Normalization preprocessing layer.
|
||||
|
||||
This class encapsulates the computations for finding the mean and variance
|
||||
of a set of data in a stable and numerically correct way. Its associated
|
||||
accumulator is a namedtuple('count', 'mean', 'variance').
|
||||
|
||||
Attributes:
|
||||
axis: The axis to compute mean and var over.
|
||||
"""
|
||||
COUNT_IDX = 0
|
||||
MEAN_IDX = 1
|
||||
VAR_IDX = 2
|
||||
|
||||
def __init__(self, axis):
|
||||
self.axis = axis
|
||||
|
||||
def compute(self, values, accumulator=None):
|
||||
"""Compute a step in this computation, returning a new accumulator."""
|
||||
values = np.array(values)
|
||||
if values.ndim == 1:
|
||||
values = np.expand_dims(values, 1)
|
||||
|
||||
# `np.delete` ignores negative indexes, so use a mask to delete items.
|
||||
axis_mask = np.ones([values.ndim], dtype=bool)
|
||||
axis_mask[np.array(self.axis, dtype=np.int32)] = False
|
||||
|
||||
# This is the shape of all reduced axes (not specified in 'axis').
|
||||
|
||||
reduction_counts = np.array(values.shape)[axis_mask]
|
||||
# We get the number of elements that will be reduced by multiplying all
|
||||
# values of 'shape' corresponding to the reduced axes.
|
||||
count = np.prod(reduction_counts, dtype=np.int64)
|
||||
|
||||
# We want to reduce across dimensions except those specified in 'axis'
|
||||
# when using np.mean or np.variance; create the tuple of axes to reduce
|
||||
# over here.
|
||||
reduction_axes = tuple(np.arange(values.ndim)[axis_mask])
|
||||
|
||||
mean = np.mean(values, axis=reduction_axes, dtype=np.float64)
|
||||
variance = np.var(values, axis=reduction_axes, dtype=np.float64)
|
||||
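    # Illustrative example: for values of shape (5, 1) such as
    # [[1], [2], [3], [4], [5]] and axis=-1, reduction_axes is (0,), so
    # count is 5, mean is [3.] and variance is [2.], matching the combiner
    # expectations exercised in normalization_test.py.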

    # Create an accumulator with our new data and either return it or combine
    # it with the passed accumulator.
    if accumulator is None:
      return self._create_accumulator(count, mean, variance)
    else:
      return self.add_data_to_accumulator(count, mean, variance, accumulator)

  def add_data_to_accumulator(self, count, mean, variance, accumulator):
    """Add new data to the totals in an accumulator."""
    # Combine accumulators and return the result.
    combined_count = count + accumulator[self.COUNT_IDX]

    # To combine accumulator means, we weight each accumulator's mean by the
    # number of elements that were accumulated, and then divide by the
    # total number of elements.
    combined_mean = (mean * count + accumulator[self.MEAN_IDX] *
                     accumulator[self.COUNT_IDX]) / combined_count

    # The variance is computed using the lack-of-fit sum of squares
    # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
    accumulator_var_contribution = accumulator[self.COUNT_IDX] * (
        accumulator[self.VAR_IDX] +
        np.square(accumulator[self.MEAN_IDX] - combined_mean))
    data_var_contribution = count * (variance + np.square(mean - combined_mean))
    combined_variance = (accumulator_var_contribution +
                         data_var_contribution) / combined_count

    accumulator[self.COUNT_IDX] = combined_count
    accumulator[self.MEAN_IDX] = np.nan_to_num(combined_mean)
    accumulator[self.VAR_IDX] = np.nan_to_num(combined_variance)
    return accumulator

  def merge(self, accumulators):
    """Merge several accumulators to a single accumulator."""
    # Combine accumulators and return the result.
    combined_count = np.sum(
        [accumulator[self.COUNT_IDX] for accumulator in accumulators])

    # To combine accumulator means, we weight each accumulator's mean by the
    # number of elements that were accumulated, and then divide by the
    # total number of elements.
    combined_mean = np.add.reduce([
        accumulator[self.MEAN_IDX] * accumulator[self.COUNT_IDX]
        for accumulator in accumulators
    ]) / combined_count

    # The variance is computed using the lack-of-fit sum of squares
    # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
    def variance_contribution(accumulator):
      return accumulator[self.COUNT_IDX] * (
          accumulator[self.VAR_IDX] +
          np.square(accumulator[self.MEAN_IDX] - combined_mean))

    combined_variance = np.add.reduce([
        variance_contribution(accumulator) for accumulator in accumulators
    ]) / combined_count

    return self._create_accumulator(combined_count, combined_mean,
                                    combined_variance)

  def extract(self, accumulator):
    """Convert an accumulator into a dict of output values."""
    return {
        _COUNT_NAME: accumulator[self.COUNT_IDX],
        _MEAN_NAME: accumulator[self.MEAN_IDX],
        _VARIANCE_NAME: accumulator[self.VAR_IDX]
    }

  def restore(self, output):
    """Create an accumulator based on 'output'."""
    # There is no special internal state here, so we just return the relevant
    # internal value.
    count = output[_COUNT_NAME]
    mean = output[_MEAN_NAME]
    var = output[_VARIANCE_NAME]
    if (count == 0 and (mean.any() != 0.0 or var.any() != 0.0)):
      raise RuntimeError(
          'The mean and/or variance of a Normalization preprocessing layer '
          "were set without also setting 'count'. If 'count' is not also set, "
          " or was set to 0, 'adapt' cannot be called unless the 'reset_state'"
          'arg is True.')
    return self._create_accumulator(output[_COUNT_NAME], output[_MEAN_NAME],
                                    output[_VARIANCE_NAME])

  def serialize(self, accumulator):
    """Serialize an accumulator for a remote call."""
    output_dict = {
        _COUNT_NAME: accumulator[self.COUNT_IDX].tolist(),
        _MEAN_NAME: accumulator[self.MEAN_IDX].tolist(),
        _VARIANCE_NAME: accumulator[self.VAR_IDX].tolist()
    }
    return compat.as_bytes(json.dumps(output_dict))

  def deserialize(self, encoded_accumulator):
    """Deserialize an accumulator received from 'serialize()'."""
    value_dict = json.loads(compat.as_text(encoded_accumulator))
    return self._create_accumulator(
        np.array(value_dict[_COUNT_NAME]), np.array(value_dict[_MEAN_NAME]),
        np.array(value_dict[_VARIANCE_NAME]))

  def _create_accumulator(self, count, mean, variance):
    """Convert any 'nan' values in the given accumulator to numeric values."""
    return [count, mean, variance]

@@ -1,7 +1,6 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Normalization"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.normalization_v1.Normalization\'>"
  is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.normalization.Normalization\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer_v1.CombinerPreprocessingLayer\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"

@@ -1,7 +1,6 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Normalization"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.normalization.Normalization\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
  is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
  is_instance: "<class \'tensorflow.python.module.module.Module\'>"