KPL - Implement Normalization with new APIs.

PiperOrigin-RevId: 356351476
Change-Id: I66f061a36cf096239fe2b0f2aaa380dba55a7569
Thomas O'Malley 2021-02-08 14:17:43 -08:00 committed by TensorFlower Gardener
parent 2cc0ab3c0c
commit 236369d651
5 changed files with 479 additions and 294 deletions
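For context, here is a minimal usage sketch of the layer after this change. It is not part of the commit; it assumes the public export path and the adapt/constructor behavior shown in the diffs below.

import numpy as np
import tensorflow as tf

# Sketch of the post-change API: adapt() computes the mean/variance state,
# after which the layer applies (x - mean) / sqrt(variance).
data = np.array([[1.], [2.], [3.], [4.], [5.]])
layer = tf.keras.layers.experimental.preprocessing.Normalization(axis=-1)
layer.adapt(data)                    # populates layer.mean / layer.variance
print(layer(data).numpy().ravel())   # approx [-1.41, -0.71, 0., 0.71, 1.41]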

View File

@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import numpy as np
from tensorflow.python.framework import dtypes
@@ -30,26 +28,13 @@ from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export
_COUNT_NAME = 'count'
_MEAN_NAME = 'mean'
_VARIANCE_NAME = 'variance'
def convert_to_ndarray(values):
if isinstance(values, np.ndarray):
return values
elif isinstance(values, ops.Tensor):
return K.get_value(values)
else:
return np.array(values)
@keras_export('keras.layers.experimental.preprocessing.Normalization', v1=[])
class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
class Normalization(base_preprocessing_layer.PreprocessingLayer):
"""Feature-wise normalization of the data.
This layer will coerce its inputs into a distribution centered around
@@ -104,6 +89,9 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
"""
def __init__(self, axis=-1, mean=None, variance=None, **kwargs):
super(Normalization, self).__init__(stateful=True, streaming=True, **kwargs)
base_preprocessing_layer.keras_kpl_gauge.get_cell('Normalization').set(True)
# Standardize `axis` to a tuple.
if axis is None:
axis = ()
@@ -111,23 +99,17 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
axis = (axis,)
else:
axis = tuple(axis)
super(Normalization, self).__init__(
combiner=_NormalizingCombiner(axis), **kwargs)
base_preprocessing_layer.keras_kpl_gauge.get_cell('Normalization').set(True)
if 0 in axis:
raise ValueError('The argument \'axis\' may not be 0.')
self.axis = axis
# Set `mean` and `variance` if passed.
if isinstance(mean, variables.Variable):
raise ValueError('Normalization does not support passing a Variable '
'for the `mean` init arg.')
if isinstance(variance, variables.Variable):
raise ValueError('Normalization does not support passing a Variable '
'for the `variance` init arg.')
if mean is not None and variance is not None:
mean = convert_to_ndarray(mean)
variance = convert_to_ndarray(variance)
@@ -135,7 +117,6 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
raise ValueError(
'When setting values directly, both `mean` and `variance` '
'must be set. Got mean: {} and variance: {}'.format(mean, variance))
self.mean_val = mean
self.variance_val = variance
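An illustrative sketch (not part of the diff) of the direct-statistics path above: when both `mean` and `variance` are passed, build() broadcasts them into the layer weights and no adapt() call is needed.

import numpy as np
import tensorflow as tf

# Sketch: precomputed statistics supplied at construction time.
layer = tf.keras.layers.experimental.preprocessing.Normalization(
    axis=-1, mean=3.0, variance=2.0)
out = layer(np.array([[1.0], [5.0]]))
print(out.numpy().ravel())  # approx [-1.414, 1.414], i.e. (x - 3) / sqrt(2)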
@@ -143,62 +124,121 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if len(input_shape) == 1:
input_shape = input_shape + [1]
ndim = len(input_shape)
# Sort `self.axis` to avoid transposing `mean_and_var_shape`.
# Negative axes are not sortable until you know the number of dimensions.
original_axis = self.axis
self.axis = tuple(sorted(self.axis,
key=lambda a: a if a >= 0 else ndim + a))
if any(a < 1 - ndim or a >= ndim for a in self.axis):
raise ValueError('All `axis` values must be in the range '
'[1 - ndim, ndim - 1]. Found '
'ndim: `{}`, axis: {}'.format(ndim, self.axis))
if any(a < 1-ndim for a in self.axis) or any(a >= ndim for a in self.axis):
raise ValueError('All `axis` values must be in '
'the range [1-ndim, ndim-1].\n'
'Got:\n'
' ndim: {}\n'
' axis: {}'.format(ndim, original_axis))
# Axes to be kept, replacing negative values with positive equivalents.
# Sorted to avoid transposing axes.
self._keep_axis = sorted([d if d >= 0 else d + ndim for d in self.axis])
# Axes to be reduced.
self._reduce_axis = [d for d in range(ndim) if d not in self._keep_axis]
# 1 if an axis should be reduced, 0 otherwise.
self._reduce_axis_mask = [
0 if d in self._keep_axis else 1 for d in range(ndim)
]
# Broadcast any reduced axes.
self._broadcast_shape = [
input_shape[d] if d in self._keep_axis else 1 for d in range(ndim)
]
# Create variables without keeping reduced axes.
mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)
self._broadcast_shape = [1 for _ in range(len(input_shape))]
mean_and_var_shape = []
for i in self.axis:
mean_and_var_shape.append(input_shape[i])
self._broadcast_shape[i] = input_shape[i]
# count is not used in this class's call() method, but is used to re-create
# the accumulator during multiple calls to 'adapt'.
# TODO(omalleyt): should mean and variance be set to self.dtype?
self.mean = self._add_state_variable(
name=_MEAN_NAME,
self.mean = self.add_weight(
name='mean',
shape=mean_and_var_shape,
dtype=K.floatx(),
initializer=init_ops.zeros_initializer)
self.variance = self._add_state_variable(
name=_VARIANCE_NAME,
dtype=self.dtype,
initializer=init_ops.zeros_initializer,
trainable=False)
self.variance = self.add_weight(
name='variance',
shape=mean_and_var_shape,
dtype=K.floatx(),
initializer=init_ops.ones_initializer)
self.count = self._add_state_variable(
name=_COUNT_NAME,
dtype=self.dtype,
initializer=init_ops.ones_initializer,
trainable=False)
self.count = self.add_weight(
name='count',
shape=(),
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer)
initializer=init_ops.zeros_initializer,
trainable=False)
super(Normalization, self).build(input_shape)
if (self.mean_val is not None and self.variance_val is not None):
mean_val = self.mean_val * np.ones(mean_and_var_shape)
variance_val = self.variance_val * np.ones(mean_and_var_shape)
self.set_weights([mean_val, variance_val])
self.mean.assign(mean_val)
self.variance.assign(variance_val)
self.built = True
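To make the axis bookkeeping in build() concrete, a small plain-Python sketch (not part of the diff) of what the attributes above resolve to for an example shape:

# Sketch: how build() resolves axis=-1 for input_shape=[None, 5, 3].
input_shape = [None, 5, 3]
ndim = len(input_shape)                                             # 3
axis = (-1,)
keep_axis = sorted(d if d >= 0 else d + ndim for d in axis)         # [2]
reduce_axis = [d for d in range(ndim) if d not in keep_axis]        # [0, 1]
broadcast_shape = [input_shape[d] if d in keep_axis else 1
                   for d in range(ndim)]                            # [1, 1, 3]
mean_and_var_shape = tuple(input_shape[d] for d in keep_axis)       # (3,)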
def update_state(self, data):
if not self.built:
raise RuntimeError('`build` must be called before `update_state`.')
data = self._standardize_inputs(data)
batch_mean, batch_variance = nn_impl.moments_v2(
data, axes=self._reduce_axis)
batch_shape = array_ops.shape(data, out_type=self.count.dtype)
batch_reduce_shape = array_ops.gather(batch_shape, self._reduce_axis)
batch_count = math_ops.reduce_prod(batch_reduce_shape)
total_count = batch_count + self.count
batch_weight = (
math_ops.cast(batch_count, dtype=self.dtype) /
math_ops.cast(total_count, dtype=self.dtype))
existing_weight = 1. - batch_weight
total_mean = self.mean * existing_weight + batch_mean * batch_weight
# The variance is computed using the lack-of-fit sum of squares
# formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
total_variance = ((self.variance +
(self.mean - total_mean)**2) * existing_weight +
(batch_variance +
(batch_mean - total_mean)**2) * batch_weight)
self.mean.assign(total_mean)
self.variance.assign(total_variance)
self.count.assign(total_count)
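update_state() pools the running moments with the new batch using count-based weights and the lack-of-fit sum-of-squares form of the variance. A short NumPy sketch (not part of the diff) checking that this arithmetic matches a direct computation on the concatenated data:

import numpy as np

existing = np.array([1., 2., 3., 4., 5.])   # data already seen (layer state)
batch = np.array([10., 11., 12.])           # new batch passed to update_state

count, mean, var = existing.size, existing.mean(), existing.var()
batch_count, batch_mean, batch_var = batch.size, batch.mean(), batch.var()

total_count = count + batch_count
batch_weight = batch_count / total_count
existing_weight = 1.0 - batch_weight
total_mean = mean * existing_weight + batch_mean * batch_weight
# Lack-of-fit sum-of-squares form of the pooled (population) variance.
total_var = ((var + (mean - total_mean) ** 2) * existing_weight +
             (batch_var + (batch_mean - total_mean) ** 2) * batch_weight)

all_data = np.concatenate([existing, batch])
assert np.isclose(total_mean, all_data.mean())   # 6.0
assert np.isclose(total_var, all_data.var())     # 16.5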
def merge_state(self, layers):
layers = layers + [self]
if any(not l.built for l in layers):
raise ValueError(
'All layers to be merged must have been adapted to some inputs '
'first (otherwise they have no state).')
layer_counts = [l.count for l in layers]
layer_means = [l.mean for l in layers]
layer_variances = [l.variance for l in layers]
total_count = math_ops.reduce_sum(layer_counts)
layer_weightings = (
math_ops.cast(layer_counts, self.dtype) /
math_ops.cast(total_count, self.dtype))
layer_weightings = array_ops.reshape(
layer_weightings, shape=[len(layers)] + [1] * self.mean.shape.rank)
total_mean = math_ops.reduce_sum(layer_means * layer_weightings, axis=0)
inter_layer_variances = (layer_means - total_mean)**2
total_variance = math_ops.reduce_sum(
((layer_variances + inter_layer_variances) * layer_weightings), axis=0)
self.mean.assign(total_mean)
self.variance.assign(total_variance)
self.count.assign(total_count)
def reset_state(self): # pylint: disable=method-hidden
if self.built:
self.mean.assign(array_ops.zeros_like(self.mean))
self.variance.assign(array_ops.ones_like(self.variance))
self.count.assign(array_ops.zeros_like(self.count))
def call(self, inputs):
inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
if inputs.shape.rank == 1:
inputs = array_ops.expand_dims(inputs, 1)
# If the inputs are not floats, cast them to floats. This avoids issues
# with int-float multiplication and division below.
if inputs.dtype != K.floatx():
inputs = math_ops.cast(inputs, K.floatx())
inputs = self._standardize_inputs(inputs)
# We need to reshape the mean and variance data to ensure that Tensorflow
# broadcasts the data correctly.
mean = array_ops.reshape(self.mean, self._broadcast_shape)
@@ -213,9 +253,9 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
return input_spec
def get_config(self):
config = {'axis': self.axis}
base_config = super(Normalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
config = super(Normalization, self).get_config()
config.update({'axis': self.axis})
return config
def set_weights(self, weights):
"""Override for set_weights to ensure we can set just mean/var weights."""
@@ -223,149 +263,22 @@ class Normalization(base_preprocessing_layer.CombinerPreprocessingLayer):
weights.append(np.array(0))
super(Normalization, self).set_weights(weights)
def _standardize_inputs(self, inputs):
inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
if inputs.shape.rank == 0:
inputs = array_ops.reshape(inputs, [1, 1])
elif inputs.shape.rank == 1:
inputs = array_ops.expand_dims(inputs, 1)
class _NormalizingCombiner(base_preprocessing_layer.Combiner):
"""Combiner for the Normalization preprocessing layer.
if inputs.dtype != self.dtype:
inputs = math_ops.cast(inputs, self.dtype)
return inputs
This class encapsulates the computations for finding the mean and variance
of a set of data in a stable and numerically correct way. Its associated
accumulator is a namedtuple('count', 'mean', 'variance').
Attributes:
axis: The axis to compute mean and var over.
"""
COUNT_IDX = 0
MEAN_IDX = 1
VAR_IDX = 2
def __init__(self, axis):
self.axis = axis
def compute(self, values, accumulator=None):
"""Compute a step in this computation, returning a new accumulator."""
values = np.array(values)
if values.ndim == 1:
values = np.expand_dims(values, 1)
# `np.delete` ignores negative indexes, so use a mask to delete items.
axis_mask = np.ones([values.ndim], dtype=bool)
axis_mask[np.array(self.axis, dtype=np.int32)] = False
# This is the shape of all reduced axes (not specified in 'axis').
reduction_counts = np.array(values.shape)[axis_mask]
# We get the number of elements that will be reduced by multiplying all
# values of 'shape' corresponding to the reduced axes.
count = np.prod(reduction_counts, dtype=np.int64)
# We want to reduce across dimensions except those specified in 'axis'
# when using np.mean or np.variance; create the tuple of axes to reduce
# over here.
reduction_axes = tuple(np.arange(values.ndim)[axis_mask])
mean = np.mean(values, axis=reduction_axes, dtype=np.float64)
variance = np.var(values, axis=reduction_axes, dtype=np.float64)
# Create an accumulator with our new data and either return it or combine
# it with the passed accumulator.
if accumulator is None:
return self._create_accumulator(count, mean, variance)
def convert_to_ndarray(values):
if isinstance(values, np.ndarray):
return values
elif isinstance(values, ops.Tensor):
return K.get_value(values)
else:
return self.add_data_to_accumulator(count, mean, variance, accumulator)
def add_data_to_accumulator(self, count, mean, variance, accumulator):
"""Add new data to the totals in an accumulator."""
# Combine accumulators and return the result.
combined_count = count + accumulator[self.COUNT_IDX]
# To combine accumulator means, we weight each accumulator's mean by the
# number of elements that were accumulated, and then divide by the
# total number of elements.
combined_mean = (mean * count + accumulator[self.MEAN_IDX] *
accumulator[self.COUNT_IDX]) / combined_count
# The variance is computed using the lack-of-fit sum of squares
# formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
accumulator_var_contribution = accumulator[self.COUNT_IDX] * (
accumulator[self.VAR_IDX] +
np.square(accumulator[self.MEAN_IDX] - combined_mean))
data_var_contribution = count * (variance + np.square(mean - combined_mean))
combined_variance = (accumulator_var_contribution +
data_var_contribution) / combined_count
accumulator[self.COUNT_IDX] = combined_count
accumulator[self.MEAN_IDX] = np.nan_to_num(combined_mean)
accumulator[self.VAR_IDX] = np.nan_to_num(combined_variance)
return accumulator
def merge(self, accumulators):
"""Merge several accumulators to a single accumulator."""
# Combine accumulators and return the result.
combined_count = np.sum(
[accumulator[self.COUNT_IDX] for accumulator in accumulators])
# To combine accumulator means, we weight each accumulator's mean by the
# number of elements that were accumulated, and then divide by the
# total number of elements.
combined_mean = np.add.reduce([
accumulator[self.MEAN_IDX] * accumulator[self.COUNT_IDX]
for accumulator in accumulators
]) / combined_count
# The variance is computed using the lack-of-fit sum of squares
# formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
def variance_contribution(accumulator):
return accumulator[self.COUNT_IDX] * (
accumulator[self.VAR_IDX] +
np.square(accumulator[self.MEAN_IDX] - combined_mean))
combined_variance = np.add.reduce([
variance_contribution(accumulator) for accumulator in accumulators
]) / combined_count
return self._create_accumulator(combined_count, combined_mean,
combined_variance)
def extract(self, accumulator):
"""Convert an accumulator into a dict of output values."""
return {
_COUNT_NAME: accumulator[self.COUNT_IDX],
_MEAN_NAME: accumulator[self.MEAN_IDX],
_VARIANCE_NAME: accumulator[self.VAR_IDX]
}
def restore(self, output):
"""Create an accumulator based on 'output'."""
# There is no special internal state here, so we just return the relevant
# internal value.
count = output[_COUNT_NAME]
mean = output[_MEAN_NAME]
var = output[_VARIANCE_NAME]
if (count == 0 and (mean.any() != 0.0 or var.any() != 0.0)):
raise RuntimeError(
'The mean and/or variance of a Normalization preprocessing layer '
"were set without also setting 'count'. If 'count' is not also set, "
"or was set to 0, 'adapt' cannot be called unless the 'reset_state' "
'arg is True.')
return self._create_accumulator(output[_COUNT_NAME], output[_MEAN_NAME],
output[_VARIANCE_NAME])
def serialize(self, accumulator):
"""Serialize an accumulator for a remote call."""
output_dict = {
_COUNT_NAME: accumulator[self.COUNT_IDX].tolist(),
_MEAN_NAME: accumulator[self.MEAN_IDX].tolist(),
_VARIANCE_NAME: accumulator[self.VAR_IDX].tolist()
}
return compat.as_bytes(json.dumps(output_dict))
def deserialize(self, encoded_accumulator):
"""Deserialize an accumulator received from 'serialize()'."""
value_dict = json.loads(compat.as_text(encoded_accumulator))
return self._create_accumulator(
np.array(value_dict[_COUNT_NAME]), np.array(value_dict[_MEAN_NAME]),
np.array(value_dict[_VARIANCE_NAME]))
def _create_accumulator(self, count, mean, variance):
"""Convert any 'nan' values in the given accumulator to numeric values."""
return [count, mean, variance]
return np.array(values)

View File

@@ -141,80 +141,6 @@ class NormalizationTest(keras_parameterized.TestCase,
expected = np.array([[3., -3., -0.33333333], [9., 5., 1.]])
self.assertAllClose(expected, output_data)
def test_combiner_api_compatibility(self):
data = np.array([[1], [2], [3], [4], [5]])
combiner = normalization._NormalizingCombiner(axis=-1)
expected = {
"count": np.array(5.0),
"variance": np.array([2.]),
"mean": np.array([3.])
}
expected_accumulator = combiner._create_accumulator(expected["count"],
expected["mean"],
expected["variance"])
self.validate_accumulator_serialize_and_deserialize(combiner, data,
expected_accumulator)
self.validate_accumulator_uniqueness(combiner, data)
self.validate_accumulator_extract(combiner, data, expected)
self.validate_accumulator_extract_and_restore(combiner, data,
expected)
@parameterized.named_parameters(
{
"data": np.array([[1], [2], [3], [4], [5]]),
"axis": -1,
"expected": {
"count": np.array(5.0),
"variance": np.array([2.]),
"mean": np.array([3.])
},
"testcase_name": "2d_single_element"
}, {
"data": np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]),
"axis": -1,
"expected": {
"count": np.array(5.0),
"mean": np.array([3., 4.]),
"variance": np.array([2., 2.])
},
"testcase_name": "2d_multi_element"
}, {
"data": np.array([[[1, 2]], [[2, 3]], [[3, 4]], [[4, 5]], [[5, 6]]]),
"axis": 2,
"expected": {
"count": np.array(5.0),
"mean": np.array([3., 4.]),
"variance": np.array([2., 2.])
},
"testcase_name": "3d_multi_element"
}, {
"data": np.array([[[1, 2]], [[2, 3]], [[3, 4]], [[4, 5]], [[5, 6]]]),
"axis": (1, 2),
"expected": {
"count": np.array(5.0),
"mean": np.array([[3., 4.]]),
"variance": np.array([[2., 2.]])
},
"testcase_name": "3d_multi_element_multi_axis"
}, {
"data":
np.array([[[1, 2], [2, 3]], [[3, 4], [4, 5]], [[1, 2], [2, 3]],
[[3, 4], [4, 5]]]),
"axis":
1,
"expected": {
"count": np.array(8.0),
"mean": np.array([2.5, 3.5]),
"variance": np.array([1.25, 1.25])
},
"testcase_name":
"3d_multi_element_internal_axis"
})
def test_combiner_computation_multi_value_axis(self, data, axis, expected):
combiner = normalization._NormalizingCombiner(axis=axis)
expected_accumulator = combiner._create_accumulator(**expected)
self.validate_accumulator_computation(combiner, data, expected_accumulator)
@parameterized.named_parameters(*_get_layer_computation_test_cases())
def test_layer_computation(self, adapt_data, axis, test_data, use_dataset,
expected):
@@ -286,6 +212,18 @@ class NormalizationTest(keras_parameterized.TestCase,
if context.executing_eagerly():
self.assertAllClose(output.numpy(), [[-1], [1], [-1], [1]])
def test_0d_data(self):
if not context.executing_eagerly():
self.skipTest("Only supported in TF2.")
data = [0, 2, 0, 2]
cls = get_layer_class()
layer = cls(axis=-1)
layer.adapt(data)
output = layer(0.)
self.assertListEqual(output.shape.as_list(), [1, 1])
self.assertAllClose(output.numpy(), [[-1]])
@parameterized.parameters(
{"axis": 0},
{"axis": (-1, 0)},
@@ -307,8 +245,7 @@ class NormalizationTest(keras_parameterized.TestCase,
def test_bad_axis_fail_build(self, axis):
cls = get_layer_class()
layer = cls(axis=axis)
with self.assertRaisesRegex(ValueError,
r"in the range \[1-ndim, ndim-1\]."):
with self.assertRaisesRegex(ValueError, r"in the range"):
layer.build([None, 2, 3])
@parameterized.parameters(
@@ -341,6 +278,33 @@ class NormalizationTest(keras_parameterized.TestCase,
keras.layers.Dense(1)])
model.summary()
def test_merge_state(self):
if not context.executing_eagerly():
self.skipTest("`merge_state` only supported in TF2")
cls = get_layer_class()
data = np.random.rand(30, 10, 2)
ds = dataset_ops.Dataset.from_tensor_slices(data).batch(2)
norm = cls(axis=(1, 2))
norm.adapt(ds)
partial_ds_1 = ds.shard(3, 0)
partial_ds_2 = ds.shard(3, 1)
partial_ds_3 = ds.shard(3, 2)
norm_1 = cls(axis=(1, 2))
norm_2 = cls(axis=(1, 2))
norm_3 = cls(axis=(1, 2))
norm_1.adapt(partial_ds_1)
norm_2.adapt(partial_ds_2)
norm_3.adapt(partial_ds_3)
norm_1.merge_state([norm_2, norm_3])
merged_norm = norm_1
self.assertAllClose(norm(data), merged_norm(data))
if __name__ == "__main__":
test.main()

View File

@@ -18,16 +18,326 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.keras.engine.base_preprocessing_layer_v1 import CombinerPreprocessingLayer
from tensorflow.python.keras.layers.preprocessing import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export
_COUNT_NAME = 'count'
_MEAN_NAME = 'mean'
_VARIANCE_NAME = 'variance'
def convert_to_ndarray(values):
if isinstance(values, np.ndarray):
return values
elif isinstance(values, ops.Tensor):
return K.get_value(values)
else:
return np.array(values)
@keras_export(v1=['keras.layers.experimental.preprocessing.Normalization'])
class Normalization(normalization.Normalization, CombinerPreprocessingLayer):
class Normalization(CombinerPreprocessingLayer):
"""Feature-wise normalization of the data.
This layer will coerce its inputs into a distribution centered around
0 with standard deviation 1. It accomplishes this by precomputing the mean and
variance of the data, and calling (input-mean)/sqrt(var) at runtime.
What happens in `adapt`: Compute mean and variance of the data and store them
as the layer's weights. `adapt` should be called before `fit`, `evaluate`,
or `predict`.
Attributes:
axis: Integer or tuple of integers, the axis or axes that should be
"kept". These axes are not summed over when calculating the
normalization statistics. By default the last axis, the `features` axis,
is kept and any `space` or `time` axes are summed. Each element in the
axes that are kept is normalized independently. If `axis` is set to
'None', the layer will perform scalar normalization (dividing the input
by a single scalar value). The `batch` axis, 0, is always summed over
(`axis=0` is not allowed).
"""
def __init__(self, axis=-1, **kwargs):
super(Normalization, self).__init__(axis, **kwargs)
base_preprocessing_layer.keras_kpl_gauge.get_cell(
'Normalization v1').set(True)
# Standardize `axis` to a tuple.
if axis is None:
axis = ()
elif isinstance(axis, int):
axis = (axis,)
else:
axis = tuple(axis)
mean = kwargs.pop('mean', None)
variance = kwargs.pop('variance', None)
super(Normalization, self).__init__(
combiner=_NormalizingCombiner(axis), **kwargs)
base_preprocessing_layer.keras_kpl_gauge.get_cell('Normalization').set(True)
if 0 in axis:
raise ValueError('The argument \'axis\' may not be 0.')
self.axis = axis
if isinstance(mean, variables.Variable):
raise ValueError('Normalization does not support passing a Variable '
'for the `mean` init arg.')
if isinstance(variance, variables.Variable):
raise ValueError('Normalization does not support passing a Variable '
'for the `variance` init arg.')
if mean is not None and variance is not None:
mean = convert_to_ndarray(mean)
variance = convert_to_ndarray(variance)
elif mean is not None or variance is not None:
raise ValueError(
'When setting values directly, both `mean` and `variance` '
'must be set. Got mean: {} and variance: {}'.format(mean, variance))
self.mean_val = mean
self.variance_val = variance
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if len(input_shape) == 1:
input_shape = input_shape + [1]
ndim = len(input_shape)
# Sort `self.axis` to avoid transposing `mean_and_var_shape`.
# Negative axes are not sortable until you know the number of dimensions.
original_axis = self.axis
self.axis = tuple(
sorted(self.axis, key=lambda a: a if a >= 0 else ndim + a))
if any(a < 1 - ndim for a in self.axis) or any(
a >= ndim for a in self.axis):
raise ValueError('All `axis` values must be in '
'the range [1-ndim, ndim-1].\n'
'Got:\n'
' ndim: {}\n'
' axis: {}'.format(ndim, original_axis))
self._broadcast_shape = [1 for _ in range(len(input_shape))]
mean_and_var_shape = []
for i in self.axis:
mean_and_var_shape.append(input_shape[i])
self._broadcast_shape[i] = input_shape[i]
# count is not used in this class's call() method, but is used to re-create
# the accumulator during multiple calls to 'adapt'.
# TODO(omalleyt): should mean and variance be set to self.dtype?
self.mean = self._add_state_variable(
name=_MEAN_NAME,
shape=mean_and_var_shape,
dtype=K.floatx(),
initializer=init_ops.zeros_initializer)
self.variance = self._add_state_variable(
name=_VARIANCE_NAME,
shape=mean_and_var_shape,
dtype=K.floatx(),
initializer=init_ops.ones_initializer)
self.count = self._add_state_variable(
name=_COUNT_NAME,
shape=(),
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer)
super(Normalization, self).build(input_shape)
if (self.mean_val is not None and self.variance_val is not None):
mean_val = self.mean_val * np.ones(mean_and_var_shape)
variance_val = self.variance_val * np.ones(mean_and_var_shape)
self.set_weights([mean_val, variance_val])
def call(self, inputs):
inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
if inputs.shape.rank == 1:
inputs = array_ops.expand_dims(inputs, 1)
# If the inputs are not floats, cast them to floats. This avoids issues
# with int-float multiplication and division below.
if inputs.dtype != K.floatx():
inputs = math_ops.cast(inputs, K.floatx())
# We need to reshape the mean and variance data to ensure that Tensorflow
# broadcasts the data correctly.
mean = array_ops.reshape(self.mean, self._broadcast_shape)
variance = array_ops.reshape(self.variance, self._broadcast_shape)
return ((inputs - mean) /
math_ops.maximum(math_ops.sqrt(variance), K.epsilon()))
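As a quick numeric check of the formula in call() (a sketch, not part of the diff), using the same toy data as the tests:

import numpy as np

x = np.array([[0.], [2.], [0.], [2.]])
mean, variance = x.mean(axis=0), x.var(axis=0)     # mean=1, variance=1
epsilon = 1e-7                                     # stands in for K.epsilon()
out = (x - mean) / np.maximum(np.sqrt(variance), epsilon)
print(out.ravel())  # [-1.  1. -1.  1.], matching the test expectations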
def compute_output_shape(self, input_shape):
return input_shape
def compute_output_signature(self, input_spec):
return input_spec
def get_config(self):
config = {'axis': self.axis}
base_config = super(Normalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def set_weights(self, weights):
"""Override for set_weights to ensure we can set just mean/var weights."""
if len(weights) == 2:
weights.append(np.array(0))
super(Normalization, self).set_weights(weights)
class _NormalizingCombiner(base_preprocessing_layer.Combiner):
"""Combiner for the Normalization preprocessing layer.
This class encapsulates the computations for finding the mean and variance
of a set of data in a stable and numerically correct way. Its associated
accumulator is a namedtuple('count', 'mean', 'variance').
Attributes:
axis: The axis to compute mean and var over.
"""
COUNT_IDX = 0
MEAN_IDX = 1
VAR_IDX = 2
def __init__(self, axis):
self.axis = axis
def compute(self, values, accumulator=None):
"""Compute a step in this computation, returning a new accumulator."""
values = np.array(values)
if values.ndim == 1:
values = np.expand_dims(values, 1)
# `np.delete` ignores negative indexes, so use a mask to delete items.
axis_mask = np.ones([values.ndim], dtype=bool)
axis_mask[np.array(self.axis, dtype=np.int32)] = False
# This is the shape of all reduced axes (not specified in 'axis').
reduction_counts = np.array(values.shape)[axis_mask]
# We get the number of elements that will be reduced by multiplying all
# values of 'shape' corresponding to the reduced axes.
count = np.prod(reduction_counts, dtype=np.int64)
# We want to reduce across dimensions except those specified in 'axis'
# when using np.mean or np.variance; create the tuple of axes to reduce
# over here.
reduction_axes = tuple(np.arange(values.ndim)[axis_mask])
mean = np.mean(values, axis=reduction_axes, dtype=np.float64)
variance = np.var(values, axis=reduction_axes, dtype=np.float64)
# Create an accumulator with our new data and either return it or combine
# it with the passed accumulator.
if accumulator is None:
return self._create_accumulator(count, mean, variance)
else:
return self.add_data_to_accumulator(count, mean, variance, accumulator)
def add_data_to_accumulator(self, count, mean, variance, accumulator):
"""Add new data to the totals in an accumulator."""
# Combine accumulators and return the result.
combined_count = count + accumulator[self.COUNT_IDX]
# To combine accumulator means, we weight each accumulator's mean by the
# number of elements that were accumulated, and then divide by the
# total number of elements.
combined_mean = (mean * count + accumulator[self.MEAN_IDX] *
accumulator[self.COUNT_IDX]) / combined_count
# The variance is computed using the lack-of-fit sum of squares
# formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
accumulator_var_contribution = accumulator[self.COUNT_IDX] * (
accumulator[self.VAR_IDX] +
np.square(accumulator[self.MEAN_IDX] - combined_mean))
data_var_contribution = count * (variance + np.square(mean - combined_mean))
combined_variance = (accumulator_var_contribution +
data_var_contribution) / combined_count
accumulator[self.COUNT_IDX] = combined_count
accumulator[self.MEAN_IDX] = np.nan_to_num(combined_mean)
accumulator[self.VAR_IDX] = np.nan_to_num(combined_variance)
return accumulator
def merge(self, accumulators):
"""Merge several accumulators to a single accumulator."""
# Combine accumulators and return the result.
combined_count = np.sum(
[accumulator[self.COUNT_IDX] for accumulator in accumulators])
# To combine accumulator means, we weight each accumulator's mean by the
# number of elements that were accumulated, and then divide by the
# total number of elements.
combined_mean = np.add.reduce([
accumulator[self.MEAN_IDX] * accumulator[self.COUNT_IDX]
for accumulator in accumulators
]) / combined_count
# The variance is computed using the lack-of-fit sum of squares
# formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
def variance_contribution(accumulator):
return accumulator[self.COUNT_IDX] * (
accumulator[self.VAR_IDX] +
np.square(accumulator[self.MEAN_IDX] - combined_mean))
combined_variance = np.add.reduce([
variance_contribution(accumulator) for accumulator in accumulators
]) / combined_count
return self._create_accumulator(combined_count, combined_mean,
combined_variance)
def extract(self, accumulator):
"""Convert an accumulator into a dict of output values."""
return {
_COUNT_NAME: accumulator[self.COUNT_IDX],
_MEAN_NAME: accumulator[self.MEAN_IDX],
_VARIANCE_NAME: accumulator[self.VAR_IDX]
}
def restore(self, output):
"""Create an accumulator based on 'output'."""
# There is no special internal state here, so we just return the relevant
# internal value.
count = output[_COUNT_NAME]
mean = output[_MEAN_NAME]
var = output[_VARIANCE_NAME]
if (count == 0 and (mean.any() != 0.0 or var.any() != 0.0)):
raise RuntimeError(
'The mean and/or variance of a Normalization preprocessing layer '
"were set without also setting 'count'. If 'count' is not also set, "
"or was set to 0, 'adapt' cannot be called unless the 'reset_state' "
'arg is True.')
return self._create_accumulator(output[_COUNT_NAME], output[_MEAN_NAME],
output[_VARIANCE_NAME])
def serialize(self, accumulator):
"""Serialize an accumulator for a remote call."""
output_dict = {
_COUNT_NAME: accumulator[self.COUNT_IDX].tolist(),
_MEAN_NAME: accumulator[self.MEAN_IDX].tolist(),
_VARIANCE_NAME: accumulator[self.VAR_IDX].tolist()
}
return compat.as_bytes(json.dumps(output_dict))
def deserialize(self, encoded_accumulator):
"""Deserialize an accumulator received from 'serialize()'."""
value_dict = json.loads(compat.as_text(encoded_accumulator))
return self._create_accumulator(
np.array(value_dict[_COUNT_NAME]), np.array(value_dict[_MEAN_NAME]),
np.array(value_dict[_VARIANCE_NAME]))
def _create_accumulator(self, count, mean, variance):
"""Convert any 'nan' values in the given accumulator to numeric values."""
return [count, mean, variance]

View File

@@ -1,7 +1,6 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Normalization"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.normalization_v1.Normalization\'>"
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.normalization.Normalization\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer_v1.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"

View File

@@ -1,7 +1,6 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Normalization"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.normalization.Normalization\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"