Benchmark and improve Normalization's 'adapt' method.

PiperOrigin-RevId: 303776842
Change-Id: I2a3ea3ffed80aa37ae5dedc20c3f3eeb0c633604
A. Unique TensorFlower 2020-03-30 10:58:18 -07:00 committed by TensorFlower Gardener
parent 238d1a70a7
commit fd0c36cf58
3 changed files with 191 additions and 25 deletions


@@ -16,3 +16,13 @@ tf_py_test(
"//tensorflow/python/keras/layers/preprocessing:index_lookup",
],
)
tf_py_test(
name = "normalization_adapt_benchmark",
srcs = ["normalization_adapt_benchmark.py"],
python_version = "PY3",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python/keras/layers/preprocessing:normalization",
],
)


@@ -0,0 +1,133 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for Keras text vectorization preprocessing layer's adapt method."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
from absl import flags
import numpy as np
from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.layers.preprocessing import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
FLAGS = flags.FLAGS
v2_compat.enable_v2_behavior()
def reduce_fn(state, values):
"""tf.data.Dataset-friendly implementation of mean and variance."""
k, n, ex, ex2 = state
# If this is the first iteration, we pick the first value to be 'k',
# which helps with precision - we assume that k is close to an average
# value and calculate mean and variance with respect to that.
k = control_flow_ops.cond(math_ops.equal(n, 0), lambda: values[0], lambda: k)
sum_v = math_ops.reduce_sum(values, axis=0)
sum_v2 = math_ops.reduce_sum(math_ops.square(values), axis=0)
ones = array_ops.ones_like(values, dtype=dtypes.int32)
batch_size = math_ops.reduce_sum(ones, axis=0)
batch_size_f = math_ops.cast(batch_size, dtypes.float32)
ex = 0 + sum_v - math_ops.multiply(batch_size_f, k)
ex2 = 0 + sum_v2 + math_ops.multiply(
batch_size_f, (math_ops.square(k) -
math_ops.multiply(math_ops.multiply(2.0, k), sum_v)))
return (k, n + batch_size, ex, ex2)
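
The reduce_fn above follows the shifted-data formulation of mean and variance: with a shift K chosen close to the data, Ex = sum(x - K) and Ex2 = sum((x - K)^2) over n elements give mean = K + Ex / n and var = (Ex2 - Ex^2 / n) / (n - 1), which is the arithmetic the benchmarked block below applies to the reduced state. A minimal NumPy sketch of that identity (illustrative only, not part of the change):

import numpy as np

def shifted_mean_var(values, k=None):
  """Mean and sample variance via the shifted-data formulation."""
  values = np.asarray(values, dtype=np.float64)
  k = values[0] if k is None else k      # a shift close to the data helps precision
  n = values.shape[0]
  ex = np.sum(values - k)                # Ex  = sum(x - K)
  ex2 = np.sum(np.square(values - k))    # Ex2 = sum((x - K)^2)
  mean = k + ex / n
  var = (ex2 - ex * ex / n) / (n - 1)    # sample variance, matches np.var(ddof=1)
  return mean, var

data = np.random.rand(1000)
mean, var = shifted_mean_var(data)
assert np.isclose(mean, np.mean(data))
assert np.isclose(var, np.var(data, ddof=1))
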
class BenchmarkAdapt(benchmark.Benchmark):
"""Benchmark adapt."""
def run_dataset_implementation(self, num_elements, batch_size):
input_t = keras.Input(shape=(1,))
layer = normalization.Normalization()
_ = layer(input_t)
num_repeats = 5
starts = []
ends = []
for _ in range(num_repeats):
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.map(
lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
ds = ds.batch(batch_size)
starts.append(time.time())
# Benchmarked code begins here.
k, n, ex, ex2 = ds.reduce((0.0, 0, 0.0, 0.0), reduce_fn)
mean = k.numpy() + ex.numpy() / n.numpy()
var = (ex2.numpy() - (ex.numpy() * ex.numpy()) / n.numpy()) / (
n.numpy() - 1)
layer.set_weights([mean, var])
# Benchmarked code ends here.
ends.append(time.time())
avg_time = np.mean(np.array(ends) - np.array(starts))
return avg_time
def bm_adapt_implementation(self, num_elements, batch_size):
"""Test the KPL adapt implementation."""
input_t = keras.Input(shape=(1,), dtype=dtypes.float32)
layer = normalization.Normalization()
_ = layer(input_t)
num_repeats = 5
starts = []
ends = []
for _ in range(num_repeats):
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.map(
lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
ds = ds.batch(batch_size)
starts.append(time.time())
# Benchmarked code begins here.
layer.adapt(ds)
# Benchmarked code ends here.
ends.append(time.time())
avg_time = np.mean(np.array(ends) - np.array(starts))
name = "normalization_adapt|%s_elements|batch_%s" % (num_elements,
batch_size)
baseline = self.run_dataset_implementation(num_elements, batch_size)
extras = {
"tf.data implementation baseline": baseline,
"delta seconds": (baseline - avg_time),
"delta percent": ((baseline - avg_time) / baseline) * 100
}
self.report_benchmark(
iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
def benchmark_vocab_size_by_batch(self):
for vocab_size in [100, 1000, 10000, 100000, 1000000]:
for batch in [1, 16, 2048]:
self.bm_adapt_implementation(vocab_size, batch)
if __name__ == "__main__":
test.main()


@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
import numpy as np
@@ -53,9 +52,9 @@ class Normalization(CombinerPreprocessingLayer):
Attributes:
axis: Integer or tuple of integers, the axis or axes that should be
normalized (typically the features axis). We will normalize each element
in the specified axis. If set to 'None', the layer will perform
scalar normalization (dividing the input by a single scalar value).
0 (the batch axis) is not allowed.
in the specified axis. If set to 'None', the layer will perform scalar
normalization (dividing the input by a single scalar value). 0 (the batch
axis) is not allowed.
"""
def __init__(self, axis=-1, dtype=None, **kwargs):
@@ -132,12 +131,6 @@ class Normalization(CombinerPreprocessingLayer):
super(Normalization, self).set_weights(weights)
class _NormalizingAccumulator(
collections.namedtuple('_NormalizingAccumulator',
['count', 'mean', 'variance'])):
pass
class _NormalizingCombiner(Combiner):
"""Combiner for the Normalization preprocessing layer.
@@ -148,6 +141,9 @@ class _NormalizingCombiner(Combiner):
Attributes:
axis: The axis to compute mean and var over.
"""
COUNT_IDX = 0
MEAN_IDX = 1
VAR_IDX = 2
def __init__(self, axis):
self.axis = axis
@@ -171,34 +167,62 @@ class _NormalizingCombiner(Combiner):
reduction_axes = None
else:
reduction_axes = tuple(np.delete(range(values.ndim), self.axis))
mean = np.mean(values, axis=reduction_axes, dtype=np.float64)
variance = np.var(values, axis=reduction_axes, dtype=np.float64)
# Create an accumulator with our new data and either return it or combine
# it with the passed accumulator.
sanitized_accumulator = self._create_accumulator(count, mean, variance)
if accumulator is None:
return sanitized_accumulator
return self._create_accumulator(count, mean, variance)
else:
return self.merge([accumulator, sanitized_accumulator])
return self.add_data_to_accumulator(count, mean, variance, accumulator)
def add_data_to_accumulator(self, count, mean, variance, accumulator):
"""Add new data to the totals in an accumulator."""
# Combine accumulators and return the result.
combined_count = count + accumulator[self.COUNT_IDX]
# To combine accumulator means, we weight each accumulator's mean by the
# number of elements that were accumulated, and then divide by the
# total number of elements.
combined_mean = (mean * count + accumulator[self.MEAN_IDX] *
accumulator[self.COUNT_IDX]) / combined_count
# The variance is computed using the lack-of-fit sum of squares
# formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
accumulator_var_contribution = accumulator[self.COUNT_IDX] * (
accumulator[self.VAR_IDX] +
np.square(accumulator[self.MEAN_IDX] - combined_mean))
data_var_contribution = count * (variance + np.square(mean - combined_mean))
combined_variance = (accumulator_var_contribution +
data_var_contribution) / combined_count
accumulator[self.COUNT_IDX] = combined_count
accumulator[self.MEAN_IDX] = np.nan_to_num(combined_mean)
accumulator[self.VAR_IDX] = np.nan_to_num(combined_variance)
return accumulator
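
The update above is the standard pooled mean/variance combination: each chunk contributes its variance plus the squared offset of its mean from the combined mean, weighted by its element count. A small NumPy check of that identity against a direct computation over the concatenated data (illustrative only, not part of the change):

import numpy as np

def combine(count_a, mean_a, var_a, count_b, mean_b, var_b):
  """Pool counts, means and (population) variances of two data chunks."""
  count = count_a + count_b
  mean = (mean_a * count_a + mean_b * count_b) / count
  var = (count_a * (var_a + np.square(mean_a - mean)) +
         count_b * (var_b + np.square(mean_b - mean))) / count
  return count, mean, var

a = np.random.rand(100)
b = np.random.rand(250)
count, mean, var = combine(a.size, a.mean(), a.var(),
                           b.size, b.mean(), b.var())
full = np.concatenate([a, b])
assert np.isclose(mean, full.mean()) and np.isclose(var, full.var())
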
def merge(self, accumulators):
"""Merge several accumulators to a single accumulator."""
# Combine accumulators and return the result.
combined_count = np.sum([accumulator.count for accumulator in accumulators])
combined_count = np.sum(
[accumulator[self.COUNT_IDX] for accumulator in accumulators])
# To combine accumulator means, we weight each accumulator's mean by the
# number of elements that were accumulated, and then divide by the
# total number of elements.
combined_mean = np.add.reduce([
accumulator.mean * accumulator.count for accumulator in accumulators
accumulator[self.MEAN_IDX] * accumulator[self.COUNT_IDX]
for accumulator in accumulators
]) / combined_count
# The variance is computed using the lack-of-fit sum of squares
# formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
def variance_contribution(accumulator):
return accumulator.count * (
accumulator.variance + np.square(accumulator.mean - combined_mean))
return accumulator[self.COUNT_IDX] * (
accumulator[self.VAR_IDX] +
np.square(accumulator[self.MEAN_IDX] - combined_mean))
combined_variance = np.add.reduce([
variance_contribution(accumulator) for accumulator in accumulators
@@ -210,9 +234,9 @@ class _NormalizingCombiner(Combiner):
def extract(self, accumulator):
"""Convert an accumulator into a dict of output values."""
return {
_COUNT_NAME: accumulator.count,
_MEAN_NAME: accumulator.mean,
_VARIANCE_NAME: accumulator.variance
_COUNT_NAME: accumulator[self.COUNT_IDX],
_MEAN_NAME: accumulator[self.MEAN_IDX],
_VARIANCE_NAME: accumulator[self.VAR_IDX]
}
def restore(self, output):
@@ -233,9 +257,9 @@ class _NormalizingCombiner(Combiner):
def serialize(self, accumulator):
"""Serialize an accumulator for a remote call."""
output_dict = {
_COUNT_NAME: accumulator.count.tolist(),
_MEAN_NAME: accumulator.mean.tolist(),
_VARIANCE_NAME: accumulator.variance.tolist()
_COUNT_NAME: accumulator[self.COUNT_IDX].tolist(),
_MEAN_NAME: accumulator[self.MEAN_IDX].tolist(),
_VARIANCE_NAME: accumulator[self.VAR_IDX].tolist()
}
return compat.as_bytes(json.dumps(output_dict))
@@ -248,5 +272,4 @@ class _NormalizingCombiner(Combiner):
def _create_accumulator(self, count, mean, variance):
"""Convert any 'nan' values in the given accumulator to numeric values."""
return _NormalizingAccumulator(
np.array(count), np.nan_to_num(mean), np.nan_to_num(variance))
return [count, mean, variance]