Benchmark and improve Normalize's 'adapt' method.

PiperOrigin-RevId: 303776842
Change-Id: I2a3ea3ffed80aa37ae5dedc20c3f3eeb0c633604
parent 238d1a70a7
commit fd0c36cf58
BUILD
@@ -16,3 +16,13 @@ tf_py_test(
         "//tensorflow/python/keras/layers/preprocessing:index_lookup",
     ],
 )
+
+tf_py_test(
+    name = "normalization_adapt_benchmark",
+    srcs = ["normalization_adapt_benchmark.py"],
+    python_version = "PY3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/layers/preprocessing:normalization",
+    ],
+)
normalization_adapt_benchmark.py (new file)
@@ -0,0 +1,133 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for Keras normalization preprocessing layer's adapt method."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl import flags
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras.layers.preprocessing import normalization
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import benchmark
+from tensorflow.python.platform import test
+
+FLAGS = flags.FLAGS
+
+v2_compat.enable_v2_behavior()
+
+
+def reduce_fn(state, values):
+  """tf.data.Dataset-friendly implementation of mean and variance."""
+  k, n, ex, ex2 = state
+  # If this is the first iteration, we pick the first value to be 'k',
+  # which helps with precision - we assume that k is close to an average
+  # value and calculate mean and variance with respect to that.
+  k = control_flow_ops.cond(math_ops.equal(n, 0), lambda: values[0], lambda: k)
+
+  sum_v = math_ops.reduce_sum(values, axis=0)
+  sum_v2 = math_ops.reduce_sum(math_ops.square(values), axis=0)
+  ones = array_ops.ones_like(values, dtype=dtypes.int32)
+  batch_size = math_ops.reduce_sum(ones, axis=0)
+  batch_size_f = math_ops.cast(batch_size, dtypes.float32)
+
+  ex = 0 + sum_v - math_ops.multiply(batch_size_f, k)
+  ex2 = 0 + sum_v2 + math_ops.multiply(
+      batch_size_f, (math_ops.square(k) -
+                     math_ops.multiply(math_ops.multiply(2.0, k), sum_v)))
+
+  return (k, n + batch_size, ex, ex2)
+
+
+class BenchmarkAdapt(benchmark.Benchmark):
+  """Benchmark adapt."""
+
+  def run_dataset_implementation(self, num_elements, batch_size):
+    input_t = keras.Input(shape=(1,))
+    layer = normalization.Normalization()
+    _ = layer(input_t)
+
+    num_repeats = 5
+    starts = []
+    ends = []
+    for _ in range(num_repeats):
+      ds = dataset_ops.Dataset.range(num_elements)
+      ds = ds.map(
+          lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
+      ds = ds.batch(batch_size)
+
+      starts.append(time.time())
+      # Benchmarked code begins here.
+      k, n, ex, ex2 = ds.reduce((0.0, 0, 0.0, 0.0), reduce_fn)
+      mean = k.numpy() + ex.numpy() / n.numpy()
+      var = (ex2.numpy() - (ex.numpy() * ex.numpy()) / n.numpy()) / (
+          n.numpy() - 1)
+      layer.set_weights([mean, var])
+      # Benchmarked code ends here.
+      ends.append(time.time())
+
+    avg_time = np.mean(np.array(ends) - np.array(starts))
+    return avg_time
+
+  def bm_adapt_implementation(self, num_elements, batch_size):
+    """Test the KPL adapt implementation."""
+    input_t = keras.Input(shape=(1,), dtype=dtypes.float32)
+    layer = normalization.Normalization()
+    _ = layer(input_t)
+
+    num_repeats = 5
+    starts = []
+    ends = []
+    for _ in range(num_repeats):
+      ds = dataset_ops.Dataset.range(num_elements)
+      ds = ds.map(
+          lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
+      ds = ds.batch(batch_size)
+
+      starts.append(time.time())
+      # Benchmarked code begins here.
+      layer.adapt(ds)
+      # Benchmarked code ends here.
+      ends.append(time.time())
+
+    avg_time = np.mean(np.array(ends) - np.array(starts))
+    name = "normalization_adapt|%s_elements|batch_%s" % (num_elements,
+                                                         batch_size)
+    baseline = self.run_dataset_implementation(num_elements, batch_size)
+    extras = {
+        "tf.data implementation baseline": baseline,
+        "delta seconds": (baseline - avg_time),
+        "delta percent": ((baseline - avg_time) / baseline) * 100
+    }
+    self.report_benchmark(
+        iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
+
+  def benchmark_vocab_size_by_batch(self):
+    for vocab_size in [100, 1000, 10000, 100000, 1000000]:
+      for batch in [1, 16, 2048]:
+        self.bm_adapt_implementation(vocab_size, batch)
+
+
+if __name__ == "__main__":
+  test.main()
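
Note: the tf.data baseline above is built on the standard shifted-data
formulas for mean and variance. As a minimal NumPy sketch (illustrative,
not part of the commit): with a shift point k, ex = sum(x - k) and
ex2 = sum((x - k)**2), so mean = k + ex/n and the sample variance is
(ex2 - ex**2/n) / (n - 1), matching the finalization in
run_dataset_implementation.

    import numpy as np

    x = np.random.rand(10000).astype(np.float32)
    k = x[0]  # shift point; ideally close to the mean, for precision
    n = x.size
    ex = np.sum(x - k, dtype=np.float64)          # first shifted moment
    ex2 = np.sum((x - k) ** 2, dtype=np.float64)  # second shifted moment

    mean = k + ex / n
    var = (ex2 - ex ** 2 / n) / (n - 1)  # sample variance (ddof=1)

    np.testing.assert_allclose(mean, np.mean(x), rtol=1e-5)
    np.testing.assert_allclose(var, np.var(x, ddof=1), rtol=1e-4)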

normalization.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import json
 
 import numpy as np
@@ -53,9 +52,9 @@ class Normalization(CombinerPreprocessingLayer):
   Attributes:
     axis: Integer or tuple of integers, the axis or axes that should be
       normalized (typically the features axis). We will normalize each element
-      in the specified axis. If set to 'None', the layer will perform
-      scalar normalization (dividing the input by a single scalar value).
-      0 (the batch axis) is not allowed.
+      in the specified axis. If set to 'None', the layer will perform scalar
+      normalization (dividing the input by a single scalar value). 0 (the batch
+      axis) is not allowed.
   """
 
   def __init__(self, axis=-1, dtype=None, **kwargs):
@@ -132,12 +131,6 @@ class Normalization(CombinerPreprocessingLayer):
     super(Normalization, self).set_weights(weights)
 
 
-class _NormalizingAccumulator(
-    collections.namedtuple('_NormalizingAccumulator',
-                           ['count', 'mean', 'variance'])):
-  pass
-
-
 class _NormalizingCombiner(Combiner):
   """Combiner for the Normalization preprocessing layer.
 
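Aside (my inference; the commit message does not state the motivation): the
namedtuple accumulator removed above is immutable and allocates a new object
on every update, whereas the plain-list accumulator introduced below can be
mutated in place once per batch. A quick illustrative timing sketch:

    import collections
    import timeit

    _Acc = collections.namedtuple('_Acc', ['count', 'mean', 'variance'])

    # Cost of building a fresh accumulator object per batch.
    print(timeit.timeit(lambda: _Acc(1, 2.0, 3.0), number=10**6))  # namedtuple
    print(timeit.timeit(lambda: [1, 2.0, 3.0], number=10**6))      # plain list

    # Namedtuples are immutable, so an in-place update such as
    # acc[COUNT_IDX] = new_count is only possible with the list form.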
@@ -148,6 +141,9 @@ class _NormalizingCombiner(Combiner):
   Attributes:
     axis: The axis to compute mean and var over.
   """
+  COUNT_IDX = 0
+  MEAN_IDX = 1
+  VAR_IDX = 2
 
   def __init__(self, axis):
     self.axis = axis
@@ -171,34 +167,62 @@ class _NormalizingCombiner(Combiner):
       reduction_axes = None
     else:
       reduction_axes = tuple(np.delete(range(values.ndim), self.axis))
 
     mean = np.mean(values, axis=reduction_axes, dtype=np.float64)
     variance = np.var(values, axis=reduction_axes, dtype=np.float64)
 
     # Create an accumulator with our new data and either return it or combine
     # it with the passed accumulator.
-    sanitized_accumulator = self._create_accumulator(count, mean, variance)
     if accumulator is None:
-      return sanitized_accumulator
+      return self._create_accumulator(count, mean, variance)
     else:
-      return self.merge([accumulator, sanitized_accumulator])
+      return self.add_data_to_accumulator(count, mean, variance, accumulator)
+
+  def add_data_to_accumulator(self, count, mean, variance, accumulator):
+    """Add new data to the totals in an accumulator."""
+    # Combine accumulators and return the result.
+    combined_count = count + accumulator[self.COUNT_IDX]
+
+    # To combine accumulator means, we weight each accumulator's mean by the
+    # number of elements that were accumulated, and then divide by the
+    # total number of elements.
+    combined_mean = (mean * count + accumulator[self.MEAN_IDX] *
+                     accumulator[self.COUNT_IDX]) / combined_count
+
+    # The variance is computed using the lack-of-fit sum of squares
+    # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
+    accumulator_var_contribution = accumulator[self.COUNT_IDX] * (
+        accumulator[self.VAR_IDX] +
+        np.square(accumulator[self.MEAN_IDX] - combined_mean))
+    data_var_contribution = count * (variance + np.square(mean - combined_mean))
+    combined_variance = (accumulator_var_contribution +
+                         data_var_contribution) / combined_count
+
+    accumulator[self.COUNT_IDX] = combined_count
+    accumulator[self.MEAN_IDX] = np.nan_to_num(combined_mean)
+    accumulator[self.VAR_IDX] = np.nan_to_num(combined_variance)
+    return accumulator
 
   def merge(self, accumulators):
     """Merge several accumulators to a single accumulator."""
     # Combine accumulators and return the result.
-    combined_count = np.sum([accumulator.count for accumulator in accumulators])
+    combined_count = np.sum(
+        [accumulator[self.COUNT_IDX] for accumulator in accumulators])
 
     # To combine accumulator means, we weight each accumulator's mean by the
     # number of elements that were accumulated, and then divide by the
     # total number of elements.
     combined_mean = np.add.reduce([
-        accumulator.mean * accumulator.count for accumulator in accumulators
+        accumulator[self.MEAN_IDX] * accumulator[self.COUNT_IDX]
+        for accumulator in accumulators
     ]) / combined_count
 
     # The variance is computed using the lack-of-fit sum of squares
     # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
     def variance_contribution(accumulator):
-      return accumulator.count * (
-          accumulator.variance + np.square(accumulator.mean - combined_mean))
+      return accumulator[self.COUNT_IDX] * (
+          accumulator[self.VAR_IDX] +
+          np.square(accumulator[self.MEAN_IDX] - combined_mean))
 
     combined_variance = np.add.reduce([
         variance_contribution(accumulator) for accumulator in accumulators
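
A small self-check (illustrative, not from the diff) that the weighted-mean
and lack-of-fit sum-of-squares formulas used by add_data_to_accumulator and
merge reproduce the statistics of the pooled data:

    import numpy as np

    a = np.random.rand(100)
    b = np.random.rand(300)

    # Per-batch statistics, as compute() would produce them (population var).
    stats = [(x.size, np.mean(x), np.var(x)) for x in (a, b)]

    combined_count = sum(n for n, _, _ in stats)
    combined_mean = sum(n * m for n, m, _ in stats) / combined_count
    # Each batch contributes its variance plus the squared distance of its
    # mean from the combined mean, weighted by its count.
    combined_var = sum(n * (v + np.square(m - combined_mean))
                       for n, m, v in stats) / combined_count

    pooled = np.concatenate([a, b])
    np.testing.assert_allclose(combined_mean, np.mean(pooled), rtol=1e-10)
    np.testing.assert_allclose(combined_var, np.var(pooled), rtol=1e-10)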
@@ -210,9 +234,9 @@ class _NormalizingCombiner(Combiner):
   def extract(self, accumulator):
     """Convert an accumulator into a dict of output values."""
     return {
-        _COUNT_NAME: accumulator.count,
-        _MEAN_NAME: accumulator.mean,
-        _VARIANCE_NAME: accumulator.variance
+        _COUNT_NAME: accumulator[self.COUNT_IDX],
+        _MEAN_NAME: accumulator[1],
+        _VARIANCE_NAME: accumulator[2]
     }
 
   def restore(self, output):
@@ -233,9 +257,9 @@ class _NormalizingCombiner(Combiner):
   def serialize(self, accumulator):
     """Serialize an accumulator for a remote call."""
     output_dict = {
-        _COUNT_NAME: accumulator.count.tolist(),
-        _MEAN_NAME: accumulator.mean.tolist(),
-        _VARIANCE_NAME: accumulator.variance.tolist()
+        _COUNT_NAME: accumulator[self.COUNT_IDX].tolist(),
+        _MEAN_NAME: accumulator[1].tolist(),
+        _VARIANCE_NAME: accumulator[2].tolist()
     }
     return compat.as_bytes(json.dumps(output_dict))
 
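Side note (illustrative; the key names below are placeholders, not the
module's actual _COUNT_NAME/_MEAN_NAME/_VARIANCE_NAME constants): the
.tolist() calls matter because numpy arrays are not directly
JSON-serializable, so serialize and restore round-trip the accumulator
through plain Python lists:

    import json

    import numpy as np

    count, mean, var = np.array(400.0), np.array([0.5]), np.array([0.083])
    blob = json.dumps({
        "count": count.tolist(),
        "mean": mean.tolist(),
        "variance": var.tolist(),
    })  # json.dumps would raise TypeError on raw ndarrays
    out = json.loads(blob)
    acc = [np.array(out["count"]), np.array(out["mean"]),
           np.array(out["variance"])]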
@@ -248,5 +272,4 @@ class _NormalizingCombiner(Combiner):
 
   def _create_accumulator(self, count, mean, variance):
     """Convert any 'nan' values in the given accumulator to numeric values."""
-    return _NormalizingAccumulator(
-        np.array(count), np.nan_to_num(mean), np.nan_to_num(variance))
+    return [count, mean, variance]
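
Finally, a hedged usage sketch of the layer under test (using the
benchmark's internal import path; the public tf.keras export may differ by
TF version). adapt() makes one pass over the data to compute mean and
variance, after which the layer standardizes its inputs:

    import numpy as np

    from tensorflow.python.compat import v2_compat
    from tensorflow.python.keras.layers.preprocessing import normalization

    v2_compat.enable_v2_behavior()

    data = 5.0 + 10.0 * np.random.rand(1000, 1).astype(np.float32)
    layer = normalization.Normalization()
    layer.adapt(data)  # computes mean and variance of `data`

    out = layer(data).numpy()
    print(out.mean(), out.std())  # approximately 0.0 and 1.0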