# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras reduction layer."""
# pylint: disable=g-classes-have-attributes

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging


def get_reduce_op(reduction_str):
  """Translate a reduction string name to a reduction op."""
  if reduction_str == "max":
    return math_ops.reduce_max
  elif reduction_str == "mean":
    return math_ops.reduce_mean
  elif reduction_str == "min":
    return math_ops.reduce_min
  elif reduction_str == "prod":
    return math_ops.reduce_prod
  elif reduction_str == "sum":
    return math_ops.reduce_sum
  else:
    raise ValueError("Reduction %s is not supported for unweighted inputs." %
                     reduction_str)


class Reduction(Layer):
"""Performs an optionally-weighted reduction.
|
|
|
|
This layer performs a reduction across one axis of its input data. This
|
|
data may optionally be weighted by passing in an identical float tensor.
|
|
|
|
Args:
|
|
reduction: The type of reduction to perform. Can be one of the following:
|
|
"max", "mean", "min", "prod", or "sum". This layer uses the Tensorflow
|
|
reduce op which corresponds to that reduction (so, for "mean", we use
|
|
"reduce_mean").
|
|
axis: The axis to reduce along. Defaults to '-2', which is usually the axis
|
|
that contains embeddings (but is not within the embedding itself).
|
|
|
|
Input shape:
|
|
A tensor of 2 or more dimensions of any numeric dtype.
|
|
|
|
Output:
|
|
A tensor of 1 less dimension than the input tensor, of the same dtype.
|
|
|
|
Call arguments:
|
|
inputs: The data to reduce.
|
|
weights: An optional tensor or constant of the same shape as inputs that
|
|
will weight the input data before it is reduced.
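
  Example (an illustrative sketch; `tf` here stands for `import tensorflow as
  tf`, which is not imported by this module itself):

    data = tf.constant([[1.0, 3.0],
                        [2.0, 4.0]])
    layer = Reduction(reduction="mean", axis=-2)
    layer(data)  # reduces across rows: [1.5, 3.5]

    weights = tf.constant([[1.0, 1.0],
                           [3.0, 3.0]])
    layer(data, weights)  # weighted mean: [1.75, 3.75]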
  """
  # TODO(momernick): Add example here.

  def __init__(self, reduction, axis=-2, **kwargs):
    self.reduction = reduction
    self.axis = axis
    # We temporarily turn off autocasting, as it does not apply to named call
    # kwargs.
    super(Reduction, self).__init__(**kwargs)

  def call(self, inputs, weights=None):
    # If we are not weighting the inputs we can immediately reduce the data
    # and return it.
    if weights is None:
      return get_reduce_op(self.reduction)(inputs, axis=self.axis)

    # TODO(momernick): Add checks for this and a decent error message if the
    # weight shape isn't compatible.
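    # If the weights have one fewer dimension than the inputs (e.g. per-step
    # weights of shape (batch, length) against inputs of shape
    # (batch, length, depth)), append a trailing axis so they broadcast when
    # multiplied with the inputs below.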
    if weights.shape.rank + 1 == inputs.shape.rank:
      weights = array_ops.expand_dims(weights, -1)

    weighted_inputs = math_ops.multiply(inputs, weights)

    # Weighted sum and prod can be expressed as reductions over the weighted
    # values, as can min and max.
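    # For example, a weighted sum of [1., 2.] with weights [3., 4.] is
    # reduce_sum([3., 8.]) = 11.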
    if self.reduction in ("sum", "prod", "min", "max"):
      return get_reduce_op(self.reduction)(weighted_inputs, axis=self.axis)

    # Weighted mean is a bit more complicated: we have to do a sum of the
    # weighted values and divide by the sum of the weights.
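    # For example, values [1., 2., 3.] with weights [1., 2., 1.] reduce to
    # (1. + 4. + 3.) / (1. + 2. + 1.) = 2.0.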
    if self.reduction == "mean":
      input_sum = math_ops.reduce_sum(weighted_inputs, axis=self.axis)
      weight_sum = math_ops.reduce_sum(weights, axis=self.axis)
      return math_ops.divide(input_sum, weight_sum)

    # sqrtn is also more complicated: it's like mean but with a normalized
    # divisor.
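    # For example, weights [3., 4.] give a divisor of sqrt(9. + 16.) = 5.,
    # rather than the weight sum of 7. that "mean" would use.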
    if self.reduction == "sqrtn":
      logging.warning("Reduction `sqrtn` is deprecated and will be removed "
                      "2021-01-01. Please use the `sum` reduction and divide "
                      "the output by the normalized weights instead.")
      input_sum = math_ops.reduce_sum(weighted_inputs, axis=self.axis)
      squared_weights = math_ops.pow(weights, 2)
      squared_weights_sum = math_ops.reduce_sum(squared_weights, axis=self.axis)
      sqrt_weights_sum = math_ops.sqrt(squared_weights_sum)
      return math_ops.divide(input_sum, sqrt_weights_sum)

    raise ValueError("%s is not a supported weighted reduction." %
                     self.reduction)