Update metrics to better support tensors that are not fully defined

If a tensor is not fully defined, broadcast_weights calls
DenseToDenseSetOperation which isn't supported by TPUs. In that
case, metrics should fallback to a less performant but supported
path.

Also disable the mlir bridge presubmit testing because the new
tests fail on the mlir bridge. It isn't possible to use
@test_util.disable_mlir_bridge() because it conflicts with
@ds_combinations.generate.

PiperOrigin-RevId: 340676883
Change-Id: I965fee2185605542a177599ccf47238f10d15da2
This commit is contained in:
Marissa Ikonomidis 2020-11-04 10:12:21 -08:00 committed by TensorFlower Gardener
parent 1dcdf0d7ff
commit a3c9e996fb
3 changed files with 45 additions and 8 deletions

View File

@@ -240,7 +240,6 @@ distribute_py_test(
distribute_py_test( distribute_py_test(
name = "custom_training_loop_metrics_test", name = "custom_training_loop_metrics_test",
srcs = ["custom_training_loop_metrics_test.py"], srcs = ["custom_training_loop_metrics_test.py"],
disable_mlir_bridge = False,
main = "custom_training_loop_metrics_test.py", main = "custom_training_loop_metrics_test.py",
tags = [ tags = [
"multi_and_single_gpu", "multi_and_single_gpu",

View File

@@ -100,6 +100,31 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
# of 10 resulting in mean of 4.5. # of 10 resulting in mean of 4.5.
self.assertEqual(metric.result().numpy(), 4.5) self.assertEqual(metric.result().numpy(), 4.5)
@ds_combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies, mode=["eager"]))
def test_update_keras_metrics_dynamic_shape(self, distribution):
with distribution.scope():
metric = metrics.Mean("test_metric", dtype=np.float32)
dataset = dataset_ops.Dataset.range(10).batch(2, drop_remainder=False)
@def_function.function
def train_fn(dataset):
weights = constant_op.constant([0.1, 0.1])
def step_fn(i):
metric.update_state(i, weights)
for i in dataset:
distribution.run(step_fn, args=(i,))
train_fn(dataset)
# This should be the mean of integers 0-9 which has a sum of 45 and a count
# of 10 resulting in mean of 4.5.
self.assertEqual(metric.result().numpy(), 4.5)
if __name__ == "__main__": if __name__ == "__main__":
multi_process_runner.test_main() multi_process_runner.test_main()

View File

@@ -383,20 +383,33 @@ class Reduce(Metric):
# Update dimensions of weights to match with values if possible. # Update dimensions of weights to match with values if possible.
values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions( values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
values, sample_weight=sample_weight) values, sample_weight=sample_weight)
try:
# Broadcast weights if possible. def reduce_values(values, sample_weight):
sample_weight = weights_broadcast_ops.broadcast_weights(
sample_weight, values)
except ValueError:
# Reduce values to same ndim as weight array # Reduce values to same ndim as weight array
ndim = K.ndim(values) ndim = K.ndim(values)
weight_ndim = K.ndim(sample_weight) weight_ndim = K.ndim(sample_weight)
if self.reduction == metrics_utils.Reduction.SUM: if self.reduction == metrics_utils.Reduction.SUM:
values = math_ops.reduce_sum( return math_ops.reduce_sum(
values, axis=list(range(weight_ndim, ndim))) values, axis=list(range(weight_ndim, ndim)))
else: else:
values = math_ops.reduce_mean( return math_ops.reduce_mean(
values, axis=list(range(weight_ndim, ndim))) values, axis=list(range(weight_ndim, ndim)))
# Attempt to broadcast weights if possible because it is more performant.
# When a tensor shape is not fully defined, use the fallback path because
# broadcast_weights with a dynamic shape input calls a GPU/TPU
# incompatible op (DenseToDenseSetOperation).
is_cpu_tensor = values.device and 'CPU' in values.device
if is_cpu_tensor or values.get_shape().is_fully_defined(
) or K.ndim(values) is None:
try:
sample_weight = weights_broadcast_ops.broadcast_weights(
sample_weight, values)
except ValueError:
values = reduce_values(values, sample_weight)
else:
values = reduce_values(values, sample_weight)
values = math_ops.multiply(values, sample_weight) values = math_ops.multiply(values, sample_weight)
value_sum = math_ops.reduce_sum(values) value_sum = math_ops.reduce_sum(values)