Update metrics to better support tensors that are not fully defined
If a tensor is not fully defined, broadcast_weights calls DenseToDenseSetOperation which isn't supported by TPUs. In that case, metrics should fallback to a less performant but supported path. Also disable the mlir bridge presubmit testing because the new tests fails on the mlir bridge. It isn't possible to use @test_util.disable_mlir_bridge() because it conflicts with @ds_combinations.generate. PiperOrigin-RevId: 340676883 Change-Id: I965fee2185605542a177599ccf47238f10d15da2
This commit is contained in:
parent
1dcdf0d7ff
commit
a3c9e996fb
@ -240,7 +240,6 @@ distribute_py_test(
|
|||||||
distribute_py_test(
|
distribute_py_test(
|
||||||
name = "custom_training_loop_metrics_test",
|
name = "custom_training_loop_metrics_test",
|
||||||
srcs = ["custom_training_loop_metrics_test.py"],
|
srcs = ["custom_training_loop_metrics_test.py"],
|
||||||
disable_mlir_bridge = False,
|
|
||||||
main = "custom_training_loop_metrics_test.py",
|
main = "custom_training_loop_metrics_test.py",
|
||||||
tags = [
|
tags = [
|
||||||
"multi_and_single_gpu",
|
"multi_and_single_gpu",
|
||||||
|
@ -100,6 +100,31 @@ class KerasMetricsTest(test.TestCase, parameterized.TestCase):
|
|||||||
# of 10 resulting in mean of 4.5.
|
# of 10 resulting in mean of 4.5.
|
||||||
self.assertEqual(metric.result().numpy(), 4.5)
|
self.assertEqual(metric.result().numpy(), 4.5)
|
||||||
|
|
||||||
|
@ds_combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=strategy_combinations.all_strategies, mode=["eager"]))
|
||||||
|
def test_update_keras_metrics_dynamic_shape(self, distribution):
|
||||||
|
with distribution.scope():
|
||||||
|
metric = metrics.Mean("test_metric", dtype=np.float32)
|
||||||
|
|
||||||
|
dataset = dataset_ops.Dataset.range(10).batch(2, drop_remainder=False)
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def train_fn(dataset):
|
||||||
|
weights = constant_op.constant([0.1, 0.1])
|
||||||
|
|
||||||
|
def step_fn(i):
|
||||||
|
metric.update_state(i, weights)
|
||||||
|
|
||||||
|
for i in dataset:
|
||||||
|
distribution.run(step_fn, args=(i,))
|
||||||
|
|
||||||
|
train_fn(dataset)
|
||||||
|
|
||||||
|
# This should be the mean of integers 0-9 which has a sum of 45 and a count
|
||||||
|
# of 10 resulting in mean of 4.5.
|
||||||
|
self.assertEqual(metric.result().numpy(), 4.5)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
multi_process_runner.test_main()
|
multi_process_runner.test_main()
|
||||||
|
@ -383,20 +383,33 @@ class Reduce(Metric):
|
|||||||
# Update dimensions of weights to match with values if possible.
|
# Update dimensions of weights to match with values if possible.
|
||||||
values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
|
values, _, sample_weight = losses_utils.squeeze_or_expand_dimensions(
|
||||||
values, sample_weight=sample_weight)
|
values, sample_weight=sample_weight)
|
||||||
try:
|
|
||||||
# Broadcast weights if possible.
|
def reduce_values(values, sample_weight):
|
||||||
sample_weight = weights_broadcast_ops.broadcast_weights(
|
|
||||||
sample_weight, values)
|
|
||||||
except ValueError:
|
|
||||||
# Reduce values to same ndim as weight array
|
# Reduce values to same ndim as weight array
|
||||||
ndim = K.ndim(values)
|
ndim = K.ndim(values)
|
||||||
weight_ndim = K.ndim(sample_weight)
|
weight_ndim = K.ndim(sample_weight)
|
||||||
if self.reduction == metrics_utils.Reduction.SUM:
|
if self.reduction == metrics_utils.Reduction.SUM:
|
||||||
values = math_ops.reduce_sum(
|
return math_ops.reduce_sum(
|
||||||
values, axis=list(range(weight_ndim, ndim)))
|
values, axis=list(range(weight_ndim, ndim)))
|
||||||
else:
|
else:
|
||||||
values = math_ops.reduce_mean(
|
return math_ops.reduce_mean(
|
||||||
values, axis=list(range(weight_ndim, ndim)))
|
values, axis=list(range(weight_ndim, ndim)))
|
||||||
|
|
||||||
|
# Attempt to broadcast weights if possible because it is more performant.
|
||||||
|
# When a tensor shape is not fully defined, use the fallback path because
|
||||||
|
# broadcast_weights with a dynamic shape input calls a GPU/TPU
|
||||||
|
# incompatible op (DenseToDenseSetOperation).
|
||||||
|
is_cpu_tensor = values.device and 'CPU' in values.device
|
||||||
|
if is_cpu_tensor or values.get_shape().is_fully_defined(
|
||||||
|
) or K.ndim(values) is None:
|
||||||
|
try:
|
||||||
|
sample_weight = weights_broadcast_ops.broadcast_weights(
|
||||||
|
sample_weight, values)
|
||||||
|
except ValueError:
|
||||||
|
values = reduce_values(values, sample_weight)
|
||||||
|
else:
|
||||||
|
values = reduce_values(values, sample_weight)
|
||||||
|
|
||||||
values = math_ops.multiply(values, sample_weight)
|
values = math_ops.multiply(values, sample_weight)
|
||||||
|
|
||||||
value_sum = math_ops.reduce_sum(values)
|
value_sum = math_ops.reduce_sum(values)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user