From a3c9e996fb912dd42eeeb501048af65e2a09814e Mon Sep 17 00:00:00 2001 From: Marissa Ikonomidis Date: Wed, 4 Nov 2020 10:12:21 -0800 Subject: [PATCH] Update metrics to better support tensors that are not fully defined If a tensor is not fully defined, broadcast_weights calls DenseToDenseSetOperation which isn't supported by TPUs. In that case, metrics should fall back to a less performant but supported path. Also disable the mlir bridge presubmit testing because the new test fails on the mlir bridge. It isn't possible to use @test_util.disable_mlir_bridge() because it conflicts with @ds_combinations.generate. PiperOrigin-RevId: 340676883 Change-Id: I965fee2185605542a177599ccf47238f10d15da2 --- tensorflow/python/keras/distribute/BUILD | 1 - .../custom_training_loop_metrics_test.py | 25 +++++++++++++++++ tensorflow/python/keras/metrics.py | 27 ++++++++++++++----- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index f52b63defb4..9f3a42eaabe 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -240,7 +240,6 @@ distribute_py_test( distribute_py_test( name = "custom_training_loop_metrics_test", srcs = ["custom_training_loop_metrics_test.py"], - disable_mlir_bridge = False, main = "custom_training_loop_metrics_test.py", tags = [ "multi_and_single_gpu", diff --git a/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py b/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py index 0ad69699d64..0bb34eee487 100644 --- a/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py +++ b/tensorflow/python/keras/distribute/custom_training_loop_metrics_test.py @@ -100,6 +100,31 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase): # of 10 resulting in mean of 4.5. 
self.assertEqual(metric.result().numpy(), 4.5) + @ds_combinations.generate( + combinations.combine( + distribution=strategy_combinations.all_strategies, mode=["eager"])) + def test_update_keras_metrics_dynamic_shape(self, distribution): + with distribution.scope(): + metric = metrics.Mean("test_metric", dtype=np.float32) + + dataset = dataset_ops.Dataset.range(10).batch(2, drop_remainder=False) + + @def_function.function + def train_fn(dataset): + weights = constant_op.constant([0.1, 0.1]) + + def step_fn(i): + metric.update_state(i, weights) + + for i in dataset: + distribution.run(step_fn, args=(i,)) + + train_fn(dataset) + + # This should be the mean of integers 0-9 which has a sum of 45 and a count + # of 10 resulting in mean of 4.5. + self.assertEqual(metric.result().numpy(), 4.5) + if __name__ == "__main__": multi_process_runner.test_main() diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py index f26b797db64..6cb51381083 100644 --- a/tensorflow/python/keras/metrics.py +++ b/tensorflow/python/keras/metrics.py @@ -383,20 +383,33 @@ class Reduce(Metric): # Update dimensions of weights to match with values if possible. values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions( values, sample_weight=sample_weight) - try: - # Broadcast weights if possible. - sample_weight = weights_broadcast_ops.broadcast_weights( - sample_weight, values) - except ValueError: + + def reduce_values(values, sample_weight): # Reduce values to same ndim as weight array ndim = K.ndim(values) weight_ndim = K.ndim(sample_weight) if self.reduction == metrics_utils.Reduction.SUM: - values = math_ops.reduce_sum( + return math_ops.reduce_sum( values, axis=list(range(weight_ndim, ndim))) else: - values = math_ops.reduce_mean( + return math_ops.reduce_mean( values, axis=list(range(weight_ndim, ndim))) + + # Attempt to broadcast weights if possible because it is more performant. 
+ # When a tensor shape is not fully defined, use the fallback path because + # broadcast_weights with a dynamic shape input calls a GPU/TPU + # incompatible op (DenseToDenseSetOperation). + is_cpu_tensor = values.device and 'CPU' in values.device + if is_cpu_tensor or values.get_shape().is_fully_defined( + ) or K.ndim(values) is None: + try: + sample_weight = weights_broadcast_ops.broadcast_weights( + sample_weight, values) + except ValueError: + values = reduce_values(values, sample_weight) + else: + values = reduce_values(values, sample_weight) + values = math_ops.multiply(values, sample_weight) value_sum = math_ops.reduce_sum(values)