From fd0c36cf58eafedb1b3f276bdcbc0afbff48de73 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 30 Mar 2020 10:58:18 -0700 Subject: [PATCH] Benchmark and improve Normalize's 'adapt' method. PiperOrigin-RevId: 303776842 Change-Id: I2a3ea3ffed80aa37ae5dedc20c3f3eeb0c633604 --- .../layers/preprocessing/benchmarks/BUILD | 10 ++ .../normalization_adapt_benchmark.py | 133 ++++++++++++++++++ .../layers/preprocessing/normalization.py | 73 ++++++---- 3 files changed, 191 insertions(+), 25 deletions(-) create mode 100644 tensorflow/python/keras/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD index 716c598466d..f488e1da34f 100644 --- a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD @@ -16,3 +16,13 @@ tf_py_test( "//tensorflow/python/keras/layers/preprocessing:index_lookup", ], ) + +tf_py_test( + name = "normalization_adapt_benchmark", + srcs = ["normalization_adapt_benchmark.py"], + python_version = "PY3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python/keras/layers/preprocessing:normalization", + ], +) diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py new file mode 100644 index 00000000000..dfce2963f75 --- /dev/null +++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/normalization_adapt_benchmark.py @@ -0,0 +1,133 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmark for Keras text vectorization preprocessing layer's adapt method.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from absl import flags +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.compat import v2_compat +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.keras.layers.preprocessing import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import test + +FLAGS = flags.FLAGS + +v2_compat.enable_v2_behavior() + + +def reduce_fn(state, values): + """tf.data.Dataset-friendly implementation of mean and variance.""" + k, n, ex, ex2 = state + # If this is the first iteration, we pick the first value to be 'k', + # which helps with precision - we assume that k is close to an average + # value and calculate mean and variance with respect to that. 
+ k = control_flow_ops.cond(math_ops.equal(n, 0), lambda: values[0], lambda: k) + + sum_v = math_ops.reduce_sum(values, axis=0) + sum_v2 = math_ops.reduce_sum(math_ops.square(values), axis=0) + ones = array_ops.ones_like(values, dtype=dtypes.int32) + batch_size = math_ops.reduce_sum(ones, axis=0) + batch_size_f = math_ops.cast(batch_size, dtypes.float32) + + ex = 0 + sum_v - math_ops.multiply(batch_size_f, k) + ex2 = 0 + sum_v2 + math_ops.multiply( + batch_size_f, (math_ops.square(k) - + math_ops.multiply(math_ops.multiply(2.0, k), sum_v))) + + return (k, n + batch_size, ex, ex2) + + +class BenchmarkAdapt(benchmark.Benchmark): + """Benchmark adapt.""" + + def run_dataset_implementation(self, num_elements, batch_size): + input_t = keras.Input(shape=(1,)) + layer = normalization.Normalization() + _ = layer(input_t) + + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + ds = dataset_ops.Dataset.range(num_elements) + ds = ds.map( + lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1)) + ds = ds.batch(batch_size) + + starts.append(time.time()) + # Benchmarked code begins here. + k, n, ex, ex2 = ds.reduce((0.0, 0, 0.0, 0.0), reduce_fn) + mean = k.numpy() + ex.numpy() / n.numpy() + var = (ex2.numpy() - (ex.numpy() * ex.numpy()) / n.numpy()) / ( + n.numpy() - 1) + layer.set_weights([mean, var]) + # Benchmarked code ends here. + ends.append(time.time()) + + avg_time = np.mean(np.array(ends) - np.array(starts)) + return avg_time + + def bm_adapt_implementation(self, num_elements, batch_size): + """Test the KPL adapt implementation.""" + input_t = keras.Input(shape=(1,), dtype=dtypes.float32) + layer = normalization.Normalization() + _ = layer(input_t) + + num_repeats = 5 + starts = [] + ends = [] + for _ in range(num_repeats): + ds = dataset_ops.Dataset.range(num_elements) + ds = ds.map( + lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1)) + ds = ds.batch(batch_size) + + starts.append(time.time()) + # Benchmarked code begins here. + layer.adapt(ds) + # Benchmarked code ends here. + ends.append(time.time()) + + avg_time = np.mean(np.array(ends) - np.array(starts)) + name = "normalization_adapt|%s_elements|batch_%s" % (num_elements, + batch_size) + baseline = self.run_dataset_implementation(num_elements, batch_size) + extras = { + "tf.data implementation baseline": baseline, + "delta seconds": (baseline - avg_time), + "delta percent": ((baseline - avg_time) / baseline) * 100 + } + self.report_benchmark( + iters=num_repeats, wall_time=avg_time, extras=extras, name=name) + + def benchmark_vocab_size_by_batch(self): + for vocab_size in [100, 1000, 10000, 100000, 1000000]: + for batch in [1, 16, 2048]: + self.bm_adapt_implementation(vocab_size, batch) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/layers/preprocessing/normalization.py b/tensorflow/python/keras/layers/preprocessing/normalization.py index 00ee2adf70d..5a0b8990486 100644 --- a/tensorflow/python/keras/layers/preprocessing/normalization.py +++ b/tensorflow/python/keras/layers/preprocessing/normalization.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import json import numpy as np @@ -53,9 +52,9 @@ class Normalization(CombinerPreprocessingLayer): Attributes: axis: Integer or tuple of integers, the axis or axes that should be normalized (typically the features axis). We will normalize each element - in the specified axis. 
If set to 'None', the layer will perform - scalar normalization (diving the input by a single scalar value). - 0 (the batch axis) is not allowed. + in the specified axis. If set to 'None', the layer will perform scalar + normalization (diving the input by a single scalar value). 0 (the batch + axis) is not allowed. """ def __init__(self, axis=-1, dtype=None, **kwargs): @@ -132,12 +131,6 @@ class Normalization(CombinerPreprocessingLayer): super(Normalization, self).set_weights(weights) -class _NormalizingAccumulator( - collections.namedtuple('_NormalizingAccumulator', - ['count', 'mean', 'variance'])): - pass - - class _NormalizingCombiner(Combiner): """Combiner for the Normalization preprocessing layer. @@ -148,6 +141,9 @@ class _NormalizingCombiner(Combiner): Attributes: axis: The axis to compute mean and var over. """ + COUNT_IDX = 0 + MEAN_IDX = 1 + VAR_IDX = 2 def __init__(self, axis): self.axis = axis @@ -171,34 +167,62 @@ class _NormalizingCombiner(Combiner): reduction_axes = None else: reduction_axes = tuple(np.delete(range(values.ndim), self.axis)) + mean = np.mean(values, axis=reduction_axes, dtype=np.float64) variance = np.var(values, axis=reduction_axes, dtype=np.float64) # Create an accumulator with our new data and either return it or combine # it with the passed accumulator. - sanitized_accumulator = self._create_accumulator(count, mean, variance) if accumulator is None: - return sanitized_accumulator + return self._create_accumulator(count, mean, variance) else: - return self.merge([accumulator, sanitized_accumulator]) + return self.add_data_to_accumulator(count, mean, variance, accumulator) + + def add_data_to_accumulator(self, count, mean, variance, accumulator): + """Add new data to the totals in an accumulator.""" + # Combine accumulators and return the result. + combined_count = count + accumulator[self.COUNT_IDX] + + # To combine accumulator means, we weight each accumulator's mean by the + # number of elements that were accumulated, and then divide by the + # total number of elements. + combined_mean = (mean * count + accumulator[self.MEAN_IDX] * + accumulator[self.COUNT_IDX]) / combined_count + + # The variance is computed using the lack-of-fit sum of squares + # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares). + accumulator_var_contribution = accumulator[self.COUNT_IDX] * ( + accumulator[self.VAR_IDX] + + np.square(accumulator[self.MEAN_IDX] - combined_mean)) + data_var_contribution = count * (variance + np.square(mean - combined_mean)) + combined_variance = (accumulator_var_contribution + + data_var_contribution) / combined_count + + accumulator[self.COUNT_IDX] = combined_count + accumulator[self.MEAN_IDX] = np.nan_to_num(combined_mean) + accumulator[self.VAR_IDX] = np.nan_to_num(combined_variance) + return accumulator def merge(self, accumulators): """Merge several accumulators to a single accumulator.""" # Combine accumulators and return the result. - combined_count = np.sum([accumulator.count for accumulator in accumulators]) + combined_count = np.sum( + [accumulator[self.COUNT_IDX] for accumulator in accumulators]) # To combine accumulator means, we weight each accumulator's mean by the # number of elements that were accumulated, and then divide by the # total number of elements. 
     combined_mean = np.add.reduce([
-        accumulator.mean * accumulator.count for accumulator in accumulators
+        accumulator[self.MEAN_IDX] * accumulator[self.COUNT_IDX]
+        for accumulator in accumulators
     ]) / combined_count
 
     # The variance is computed using the lack-of-fit sum of squares
     # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
     def variance_contribution(accumulator):
-      return accumulator.count * (
-          accumulator.variance + np.square(accumulator.mean - combined_mean))
+      return accumulator[self.COUNT_IDX] * (
+          accumulator[self.VAR_IDX] +
+          np.square(accumulator[self.MEAN_IDX] - combined_mean))
 
     combined_variance = np.add.reduce([
         variance_contribution(accumulator) for accumulator in accumulators
     ]) / combined_count
@@ -210,9 +234,9 @@ class _NormalizingCombiner(Combiner):
   def extract(self, accumulator):
     """Convert an accumulator into a dict of output values."""
     return {
-        _COUNT_NAME: accumulator.count,
-        _MEAN_NAME: accumulator.mean,
-        _VARIANCE_NAME: accumulator.variance
+        _COUNT_NAME: accumulator[self.COUNT_IDX],
+        _MEAN_NAME: accumulator[self.MEAN_IDX],
+        _VARIANCE_NAME: accumulator[self.VAR_IDX]
     }
 
   def restore(self, output):
@@ -233,9 +257,9 @@ class _NormalizingCombiner(Combiner):
   def serialize(self, accumulator):
     """Serialize an accumulator for a remote call."""
     output_dict = {
-        _COUNT_NAME: accumulator.count.tolist(),
-        _MEAN_NAME: accumulator.mean.tolist(),
-        _VARIANCE_NAME: accumulator.variance.tolist()
+        _COUNT_NAME: accumulator[self.COUNT_IDX].tolist(),
+        _MEAN_NAME: accumulator[self.MEAN_IDX].tolist(),
+        _VARIANCE_NAME: accumulator[self.VAR_IDX].tolist()
     }
     return compat.as_bytes(json.dumps(output_dict))
 
@@ -248,5 +272,4 @@ class _NormalizingCombiner(Combiner):
 
   def _create_accumulator(self, count, mean, variance):
-    """Convert any 'nan' values in the given accumulator to numeric values."""
-    return _NormalizingAccumulator(
-        np.array(count), np.nan_to_num(mean), np.nan_to_num(variance))
+    """Create an accumulator from the given count, mean, and variance."""
+    return [count, mean, variance]
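
As a sanity check on the math above: the combiner merges per-batch (count, mean, variance) triples with a weighted mean and the lack-of-fit sum-of-squares formula, while the benchmark's tf.data baseline accumulates shifted sums (ex, ex2) relative to a pivot value k. The standalone NumPy sketch below is not part of the patch; the helper names are invented for illustration, and plain NumPy stands in for the TF ops. It checks both routes against a direct np.mean/np.var over the concatenated data. Note the difference in conventions: the combiner produces the population variance (divide by n), whereas the baseline in the benchmark divides by n - 1.

import numpy as np


def merge_count_mean_var(a, b):
  """Combine two (count, mean, variance) summaries, as the combiner does."""
  count_a, mean_a, var_a = a
  count_b, mean_b, var_b = b
  count = count_a + count_b
  # Weighted mean of the two partitions.
  mean = (mean_a * count_a + mean_b * count_b) / count
  # Lack-of-fit sum of squares: each partition contributes its own variance
  # plus the squared offset of its mean from the combined mean.
  var = (count_a * (var_a + np.square(mean_a - mean)) +
         count_b * (var_b + np.square(mean_b - mean))) / count
  return count, mean, var


def shifted_mean_var(batches):
  """Single-pass mean/variance using the shifted-data formula."""
  k = batches[0][0]  # Pivot: the first value seen, assumed near the mean.
  n = 0
  ex = 0.0   # Running sum of (x - k).
  ex2 = 0.0  # Running sum of (x - k)**2.
  for values in batches:
    n += len(values)
    ex += np.sum(values - k)
    ex2 += np.sum(np.square(values - k))
  mean = k + ex / n
  sample_var = (ex2 - ex * ex / n) / (n - 1)  # Unbiased (n - 1) estimate.
  return mean, sample_var


if __name__ == "__main__":
  rng = np.random.default_rng(0)
  data = rng.normal(loc=3.0, scale=2.0, size=10000)
  batches = np.split(data, 10)

  # Merge per-batch summaries the way merge()/add_data_to_accumulator do.
  summaries = [(len(b), np.mean(b), np.var(b)) for b in batches]
  count, mean, var = summaries[0]
  for s in summaries[1:]:
    count, mean, var = merge_count_mean_var((count, mean, var), s)
  np.testing.assert_allclose(mean, np.mean(data))
  np.testing.assert_allclose(var, np.var(data))

  # The shifted-sum route matches the direct sample statistics.
  s_mean, s_var = shifted_mean_var(batches)
  np.testing.assert_allclose(s_mean, np.mean(data))
  np.testing.assert_allclose(s_var, np.var(data, ddof=1))

Keeping only (count, mean, variance) per partition is what lets adapt run in a single pass over a dataset and lets accumulators from different batches or workers be merged in any order.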