Fix issue - 1. tf.keras.Model.compile metrics do not respect masking since TF 2.2
2. Prevent passing a loss as metric into model.compile PiperOrigin-RevId: 314962578 Change-Id: If9c69173ba2ca6d475c2b3d06bdaffad72e29f7b
This commit is contained in:
parent
1c7fb569b4
commit
1c144fdb67
@ -200,8 +200,7 @@ class LossesContainer(Container):
|
||||
continue
|
||||
|
||||
y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
|
||||
sw = apply_mask(y_p, sw)
|
||||
|
||||
sw = apply_mask(y_p, sw, get_mask(y_p))
|
||||
loss_value = loss_obj(y_t, y_p, sample_weight=sw)
|
||||
|
||||
loss_metric_value = loss_value
|
||||
@ -401,12 +400,13 @@ class MetricsContainer(Container):
|
||||
continue
|
||||
|
||||
y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
|
||||
sw = apply_mask(y_p, sw)
|
||||
mask = get_mask(y_p)
|
||||
sw = apply_mask(y_p, sw, mask)
|
||||
|
||||
for metric_obj in metric_objs:
|
||||
if metric_obj is None:
|
||||
continue
|
||||
metric_obj.update_state(y_t, y_p)
|
||||
metric_obj.update_state(y_t, y_p, sample_weight=mask)
|
||||
|
||||
for weighted_metric_obj in weighted_metric_objs:
|
||||
if weighted_metric_obj is None:
|
||||
@ -461,6 +461,9 @@ class MetricsContainer(Container):
|
||||
else:
|
||||
metric_obj = metrics_mod.categorical_crossentropy
|
||||
|
||||
if isinstance(metric_obj, losses_mod.Loss):
|
||||
metric_obj._allow_sum_over_batch_size = True # pylint: disable=protected-access
|
||||
|
||||
if not isinstance(metric_obj, metrics_mod.Metric):
|
||||
if isinstance(metric, six.string_types):
|
||||
metric_name = metric
|
||||
@ -620,10 +623,13 @@ def match_dtype_and_rank(y_t, y_p, sw):
|
||||
return y_t, y_p, sw
|
||||
|
||||
|
||||
def apply_mask(y_p, sw):
|
||||
def get_mask(y_p):
|
||||
"""Returns Keras mask from tensor."""
|
||||
return getattr(y_p, '_keras_mask', None)
|
||||
|
||||
|
||||
def apply_mask(y_p, sw, mask):
|
||||
"""Applies any mask on predictions to sample weights."""
|
||||
# Handle Keras mask on outputs.
|
||||
mask = getattr(y_p, '_keras_mask', None)
|
||||
if mask is not None:
|
||||
mask = math_ops.cast(mask, y_p.dtype)
|
||||
if sw is not None:
|
||||
|
@ -18,11 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import one_device_strategy
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import losses as losses_mod
|
||||
from tensorflow.python.keras import metrics as metrics_mod
|
||||
from tensorflow.python.keras.engine import compile_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -289,6 +291,53 @@ class LossesContainerTest(keras_parameterized.TestCase):
|
||||
total_loss = loss_container(y_t, y_p)
|
||||
self.assertEqual(total_loss.dtype, dtypes.float64)
|
||||
|
||||
def test_loss_masking(self):
|
||||
loss_container = compile_utils.LossesContainer('mae')
|
||||
y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
|
||||
y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
|
||||
y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
|
||||
dtype=dtypes.float32)
|
||||
|
||||
total_loss = loss_container(y_t, y_p)
|
||||
self.assertAlmostEqual(total_loss.numpy(), .25) # sum over batch size
|
||||
|
||||
self.assertLen(loss_container.metrics, 1)
|
||||
loss_metric = loss_container.metrics[0]
|
||||
self.assertEqual(loss_metric.name, 'loss')
|
||||
self.assertAlmostEqual(loss_metric.result().numpy(), .25)
|
||||
|
||||
def test_loss_sample_weight(self):
|
||||
loss_container = compile_utils.LossesContainer('mae')
|
||||
y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
|
||||
y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
|
||||
sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)
|
||||
|
||||
total_loss = loss_container(y_t, y_p, sample_weight=sw)
|
||||
# (0 * .2 + 0 * .3 + 1 * .5 + 1 * 0) / 4
|
||||
self.assertAlmostEqual(total_loss.numpy(), .125)
|
||||
|
||||
self.assertLen(loss_container.metrics, 1)
|
||||
loss_metric = loss_container.metrics[0]
|
||||
self.assertEqual(loss_metric.name, 'loss')
|
||||
self.assertAlmostEqual(loss_metric.result().numpy(), .125)
|
||||
|
||||
def test_loss_masking_sample_weight(self):
|
||||
loss_container = compile_utils.LossesContainer('mae')
|
||||
y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
|
||||
y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
|
||||
sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)
|
||||
y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
|
||||
dtype=dtypes.float32)
|
||||
|
||||
total_loss = loss_container(y_t, y_p, sample_weight=sw)
|
||||
# (0 * .2 + 1 * .5) / 4
|
||||
self.assertAlmostEqual(total_loss.numpy(), .125) # sum over batch size
|
||||
|
||||
self.assertLen(loss_container.metrics, 1)
|
||||
loss_metric = loss_container.metrics[0]
|
||||
self.assertEqual(loss_metric.name, 'loss')
|
||||
self.assertAlmostEqual(loss_metric.result().numpy(), .125)
|
||||
|
||||
|
||||
class MetricsContainerTest(keras_parameterized.TestCase):
|
||||
|
||||
@ -566,6 +615,76 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(mse_metric.name, 'output3_mse')
|
||||
self.assertEqual(mse_metric.result().numpy(), 4.)
|
||||
|
||||
def test_metrics_masking(self):
|
||||
metrics_container = compile_utils.MetricsContainer(
|
||||
metrics=['mae'], weighted_metrics=['mse'])
|
||||
y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
|
||||
y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
|
||||
y_p._keras_mask = constant_op.constant([[1, 1], [0, 0]],
|
||||
dtype=dtypes.float32)
|
||||
|
||||
metrics_container.update_state(y_t, y_p)
|
||||
self.assertLen(metrics_container.metrics, 2)
|
||||
|
||||
mae_metric = metrics_container.metrics[0]
|
||||
self.assertEqual(mae_metric.name, 'mae')
|
||||
self.assertAlmostEqual(mae_metric.result().numpy(), 0)
|
||||
|
||||
weighted_mae_metric = metrics_container.metrics[1]
|
||||
self.assertEqual(weighted_mae_metric.name, 'mse')
|
||||
self.assertAlmostEqual(weighted_mae_metric.result().numpy(), 0)
|
||||
|
||||
def test_metrics_sample_weight(self):
|
||||
metrics_container = compile_utils.MetricsContainer(
|
||||
metrics=['mae'], weighted_metrics=['mse'])
|
||||
y_p = constant_op.constant([[[1], [1]], [[0], [1]]], dtype=dtypes.float32)
|
||||
y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
|
||||
sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)
|
||||
|
||||
metrics_container.update_state(y_t, y_p, sample_weight=sw)
|
||||
self.assertLen(metrics_container.metrics, 2)
|
||||
|
||||
mae_metric = metrics_container.metrics[0]
|
||||
self.assertEqual(mae_metric.name, 'mae')
|
||||
self.assertAlmostEqual(mae_metric.result().numpy(), .25) # 1 / 4
|
||||
|
||||
weighted_mae_metric = metrics_container.metrics[1]
|
||||
self.assertEqual(weighted_mae_metric.name, 'mse')
|
||||
self.assertAlmostEqual(weighted_mae_metric.result().numpy(), .5) # .5 / 1
|
||||
|
||||
def test_metrics_masking_sample_weight(self):
|
||||
metrics_container = compile_utils.MetricsContainer(
|
||||
metrics=['mae'], weighted_metrics=['mse'])
|
||||
y_p = constant_op.constant([[[1], [1]], [[0], [1]]], dtype=dtypes.float32)
|
||||
y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
|
||||
sw = constant_op.constant([[.3, .2], [.2, .3]], dtype=dtypes.float32)
|
||||
y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
|
||||
dtype=dtypes.float32)
|
||||
|
||||
metrics_container.update_state(y_t, y_p, sample_weight=sw)
|
||||
self.assertLen(metrics_container.metrics, 2)
|
||||
|
||||
mae_metric = metrics_container.metrics[0]
|
||||
self.assertEqual(mae_metric.name, 'mae')
|
||||
self.assertAlmostEqual(mae_metric.result().numpy(), .5) # 1 / .5
|
||||
|
||||
weighted_mae_metric = metrics_container.metrics[1]
|
||||
self.assertEqual(weighted_mae_metric.name, 'mse')
|
||||
self.assertAlmostEqual(weighted_mae_metric.result().numpy(), .2 / .5)
|
||||
|
||||
def test_loss_class_as_metric_with_distribution(self):
|
||||
distribution = one_device_strategy.OneDeviceStrategy('/device:CPU:0')
|
||||
with distribution.scope():
|
||||
metric_container = compile_utils.MetricsContainer(
|
||||
losses_mod.MeanSquaredError())
|
||||
y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
|
||||
metric_container.update_state(y_t, y_p)
|
||||
|
||||
self.assertLen(metric_container.metrics, 1)
|
||||
metric = metric_container.metrics[0]
|
||||
self.assertEqual(metric.name, 'mean_squared_error')
|
||||
self.assertEqual(metric.result().numpy(), 1.)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
x
Reference in New Issue
Block a user