Add support for weight broadcasting in reduction.Reduction.

PiperOrigin-RevId: 294019292
Change-Id: I44c4fb6b6e19711895749a6de7a0cfb483228014
This commit is contained in:
A. Unique TensorFlower 2020-02-08 14:19:32 -08:00 committed by TensorFlower Gardener
parent 6fa26ec475
commit f66df8ddd2
2 changed files with 38 additions and 0 deletions
tensorflow/python/keras/layers/preprocessing

View File

@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
@ -82,6 +83,11 @@ class Reduction(Layer):
if weights is None:
return get_reduce_op(self.reduction)(inputs, axis=self.axis)
# TODO(momernick): Add checks for this and a decent error message if the
# weight shape isn't compatible.
if weights.shape.rank + 1 == inputs.shape.rank:
weights = array_ops.expand_dims(weights, -1)
weighted_inputs = math_ops.multiply(inputs, weights)
# Weighted sum and prod can be expressed as reductions over the weighted

View File

@ -108,6 +108,22 @@ class ReductionTest(keras_parameterized.TestCase):
output = model.predict([data, weights])
self.assertAllClose(expected_output, output)
def test_weighted_ragged_reduction_with_different_dimensionality(self):
data = ragged_factory_ops.constant([[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
[[3.0, 1.0], [1.0, 2.0]]])
input_tensor = keras.Input(shape=(None, None), ragged=True)
weights = ragged_factory_ops.constant([[1.0, 2.0, 1.0], [1.0, 1.0]])
weight_input_tensor = keras.Input(shape=(None,), ragged=True)
output_tensor = reduction.Reduction(reduction="mean")(
input_tensor, weights=weight_input_tensor)
model = keras.Model([input_tensor, weight_input_tensor], output_tensor)
output = model.predict([data, weights])
expected_output = [[2.0, 2.0], [2.0, 1.5]]
self.assertAllClose(expected_output, output)
@parameterized.named_parameters(
{
"testcase_name": "max",
@ -185,6 +201,22 @@ class ReductionTest(keras_parameterized.TestCase):
self.assertAllClose(expected_output, output)
def test_weighted_dense_reduction_with_different_dimensionality(self):
data = np.array([[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
[[3.0, 1.0], [1.0, 2.0], [0.0, 0.0]]])
input_tensor = keras.Input(shape=(None, None))
weights = np.array([[1.0, 2.0, 1.0], [1.0, 1.0, 0.0]])
weight_input_tensor = keras.Input(shape=(None,))
output_tensor = reduction.Reduction(reduction="mean")(
input_tensor, weights=weight_input_tensor)
model = keras.Model([input_tensor, weight_input_tensor], output_tensor)
output = model.predict([data, weights])
expected_output = [[2.0, 2.0], [2.0, 1.5]]
self.assertAllClose(expected_output, output)
def test_sqrtn_fails_on_unweighted_ragged(self):
input_tensor = keras.Input(shape=(None, None), ragged=True)
with self.assertRaisesRegex(ValueError, ".*sqrtn.*"):