Add support for weight broadcasting in reduction.Reduction.
PiperOrigin-RevId: 294019292 Change-Id: I44c4fb6b6e19711895749a6de7a0cfb483228014
This commit is contained in:
parent
6fa26ec475
commit
f66df8ddd2
tensorflow/python/keras/layers/preprocessing
@@ -20,6 +20,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras.engine.base_layer import Layer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
@@ -82,6 +83,11 @@ class Reduction(Layer):
|
||||
if weights is None:
|
||||
return get_reduce_op(self.reduction)(inputs, axis=self.axis)
|
||||
|
||||
# TODO(momernick): Add checks for this and a decent error message if the
|
||||
# weight shape isn't compatible.
|
||||
if weights.shape.rank + 1 == inputs.shape.rank:
|
||||
weights = array_ops.expand_dims(weights, -1)
|
||||
|
||||
weighted_inputs = math_ops.multiply(inputs, weights)
|
||||
|
||||
# Weighted sum and prod can be expressed as reductions over the weighted
|
||||
|
@@ -108,6 +108,22 @@ class ReductionTest(keras_parameterized.TestCase):
|
||||
output = model.predict([data, weights])
|
||||
self.assertAllClose(expected_output, output)
|
||||
|
||||
def test_weighted_ragged_reduction_with_different_dimensionality(self):
  """Mean-reduction over ragged data with weights one rank lower.

  The weights are rank-2 (batch, timesteps) while the data is rank-3
  (batch, timesteps, features); the layer is expected to broadcast the
  weights across the feature axis.
  """
  # Ragged batch: first row has 3 timesteps, second row has 2.
  ragged_data = ragged_factory_ops.constant(
      [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
       [[3.0, 1.0], [1.0, 2.0]]])
  # Per-timestep weights, no feature dimension.
  ragged_weights = ragged_factory_ops.constant(
      [[1.0, 2.0, 1.0], [1.0, 1.0]])

  data_input = keras.Input(shape=(None, None), ragged=True)
  weight_input = keras.Input(shape=(None,), ragged=True)
  reduced = reduction.Reduction(reduction="mean")(
      data_input, weights=weight_input)
  model = keras.Model([data_input, weight_input], reduced)

  result = model.predict([ragged_data, ragged_weights])
  # Row 0: weighted mean of [1,2,3]*[1,2,1] -> 8/4 = 2.0 per feature.
  # Row 1: ([3,1]+[1,2])/2 -> [2.0, 1.5].
  self.assertAllClose([[2.0, 2.0], [2.0, 1.5]], result)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
"testcase_name": "max",
|
||||
@@ -185,6 +201,22 @@ class ReductionTest(keras_parameterized.TestCase):
|
||||
|
||||
self.assertAllClose(expected_output, output)
|
||||
|
||||
def test_weighted_dense_reduction_with_different_dimensionality(self):
  """Mean-reduction over dense data with weights one rank lower.

  Dense counterpart of the ragged broadcasting test: rank-3 data with
  rank-2 weights, where a zero weight masks out the padded timestep of
  the second example.
  """
  dense_data = np.array(
      [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],
       [[3.0, 1.0], [1.0, 2.0], [0.0, 0.0]]])
  # Zero weight on the last step of row 1 excludes the padding.
  dense_weights = np.array([[1.0, 2.0, 1.0], [1.0, 1.0, 0.0]])

  data_input = keras.Input(shape=(None, None))
  weight_input = keras.Input(shape=(None,))
  reduced = reduction.Reduction(reduction="mean")(
      data_input, weights=weight_input)
  model = keras.Model([data_input, weight_input], reduced)

  result = model.predict([dense_data, dense_weights])
  # Row 0: weighted mean 8/4 = 2.0 per feature; row 1: [2.0, 1.5].
  self.assertAllClose([[2.0, 2.0], [2.0, 1.5]], result)
|
||||
|
||||
def test_sqrtn_fails_on_unweighted_ragged(self):
|
||||
input_tensor = keras.Input(shape=(None, None), ragged=True)
|
||||
with self.assertRaisesRegex(ValueError, ".*sqrtn.*"):
|
||||
|
Loading…
Reference in New Issue
Block a user