Fix two issues:

1. `tf.keras.Model.compile` metrics do not respect masking (since TF 2.2).

2. A `Loss` instance passed as a metric to `model.compile` fails under a distribution strategy.

PiperOrigin-RevId: 314962578
Change-Id: If9c69173ba2ca6d475c2b3d06bdaffad72e29f7b
Author: Pavithra Vijay, 2020-06-05 11:22:25 -07:00 (committed by TensorFlower Gardener)
parent 1c7fb569b4
commit 1c144fdb67
2 changed files with 132 additions and 7 deletions
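
For context, the user-visible symptom behind item 1: since TF 2.2, compiled metrics ignored the `_keras_mask` that layers such as `Masking` attach to their outputs, so padded timesteps were averaged into metric results even though the loss excluded them. A minimal sketch of the affected pattern (illustrative model and data, not part of this commit):

import numpy as np
import tensorflow as tf

# A Masking layer attaches `_keras_mask` to its outputs; the loss honored it,
# but before this fix the compiled 'mae' metric averaged over masked steps too.
model = tf.keras.Sequential([
    tf.keras.layers.Masking(mask_value=0., input_shape=(2, 1)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mae', metrics=['mae'])

x = np.array([[[1.], [0.]], [[1.], [0.]]])  # second timestep is padding
y = np.ones((2, 2, 1))
model.fit(x, y, epochs=1, verbose=0)  # 'mae' now reflects only unmasked steps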

tensorflow/python/keras/engine/compile_utils.py

@@ -200,8 +200,7 @@ class LossesContainer(Container):
         continue

       y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
-      sw = apply_mask(y_p, sw)
+      sw = apply_mask(y_p, sw, get_mask(y_p))

       loss_value = loss_obj(y_t, y_p, sample_weight=sw)
       loss_metric_value = loss_value
@@ -401,12 +400,13 @@ class MetricsContainer(Container):
         continue

       y_t, y_p, sw = match_dtype_and_rank(y_t, y_p, sw)
-      sw = apply_mask(y_p, sw)
+      mask = get_mask(y_p)
+      sw = apply_mask(y_p, sw, mask)

       for metric_obj in metric_objs:
         if metric_obj is None:
           continue
-        metric_obj.update_state(y_t, y_p)
+        metric_obj.update_state(y_t, y_p, sample_weight=mask)

       for weighted_metric_obj in weighted_metric_objs:
         if weighted_metric_obj is None:
@@ -461,6 +461,9 @@
       else:
         metric_obj = metrics_mod.categorical_crossentropy

+    if isinstance(metric_obj, losses_mod.Loss):
+      metric_obj._allow_sum_over_batch_size = True  # pylint: disable=protected-access
+
     if not isinstance(metric_obj, metrics_mod.Metric):
       if isinstance(metric, six.string_types):
         metric_name = metric
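
The `_allow_sum_over_batch_size` escape hatch added above is what makes a `Loss` instance usable as a metric under a distribution strategy, where its default `SUM_OVER_BATCH_SIZE` reduction would otherwise be rejected. A sketch of the public-API usage this appears to enable (mirrors the new `test_loss_class_as_metric_with_distribution` test added below):

import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy('/device:CPU:0')
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(5,))])
  # A Loss instance as a metric previously raised a reduction error inside
  # a strategy scope; this hunk allows it.
  model.compile(optimizer='sgd', loss='mse',
                metrics=[tf.keras.losses.MeanSquaredError()])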
@@ -620,10 +623,13 @@ def match_dtype_and_rank(y_t, y_p, sw):
   return y_t, y_p, sw


-def apply_mask(y_p, sw):
+def get_mask(y_p):
+  """Returns Keras mask from tensor."""
+  return getattr(y_p, '_keras_mask', None)
+
+
+def apply_mask(y_p, sw, mask):
   """Applies any mask on predictions to sample weights."""
-  # Handle Keras mask on outputs.
-  mask = getattr(y_p, '_keras_mask', None)
   if mask is not None:
     mask = math_ops.cast(mask, y_p.dtype)
     if sw is not None:
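
Taken together, the refactor splits mask retrieval from mask application so the metrics path can pass the raw mask to unweighted metrics while still folding it into sample weights. A simplified standalone sketch of the two helpers' behavior (omits the rank-matching the real code does via `squeeze_or_expand_dimensions`):

import tensorflow as tf

def get_mask(y_p):
  # The Keras mask, when present, rides along on the prediction tensor.
  return getattr(y_p, '_keras_mask', None)

def apply_mask(y_p, sw, mask):
  # Fold the mask into the sample weights: masked-out steps get weight 0.
  if mask is not None:
    mask = tf.cast(mask, y_p.dtype)
    sw = mask if sw is None else sw * mask
  return sw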

tensorflow/python/keras/engine/compile_utils_test.py

@@ -18,11 +18,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.distribute import one_device_strategy
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import losses as losses_mod
 from tensorflow.python.keras import metrics as metrics_mod
 from tensorflow.python.keras.engine import compile_utils
 from tensorflow.python.ops import array_ops
@@ -289,6 +291,53 @@ class LossesContainerTest(keras_parameterized.TestCase):

     total_loss = loss_container(y_t, y_p)
     self.assertEqual(total_loss.dtype, dtypes.float64)

+  def test_loss_masking(self):
+    loss_container = compile_utils.LossesContainer('mae')
+    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
+    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
+    y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
+                                           dtype=dtypes.float32)
+
+    total_loss = loss_container(y_t, y_p)
+    self.assertAlmostEqual(total_loss.numpy(), .25)  # sum over batch size
+
+    self.assertLen(loss_container.metrics, 1)
+    loss_metric = loss_container.metrics[0]
+    self.assertEqual(loss_metric.name, 'loss')
+    self.assertAlmostEqual(loss_metric.result().numpy(), .25)
+
+  def test_loss_sample_weight(self):
+    loss_container = compile_utils.LossesContainer('mae')
+    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
+    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
+    sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)
+
+    total_loss = loss_container(y_t, y_p, sample_weight=sw)
+    # (0 * .2 + 0 * .3 + 1 * .5 + 1 * 0) / 4
+    self.assertAlmostEqual(total_loss.numpy(), .125)
+
+    self.assertLen(loss_container.metrics, 1)
+    loss_metric = loss_container.metrics[0]
+    self.assertEqual(loss_metric.name, 'loss')
+    self.assertAlmostEqual(loss_metric.result().numpy(), .125)
+
+  def test_loss_masking_sample_weight(self):
+    loss_container = compile_utils.LossesContainer('mae')
+    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
+    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
+    sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)
+    y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
+                                           dtype=dtypes.float32)
+
+    total_loss = loss_container(y_t, y_p, sample_weight=sw)
+    # (0 * .2 + 1 * .5) / 4
+    self.assertAlmostEqual(total_loss.numpy(), .125)  # sum over batch size
+
+    self.assertLen(loss_container.metrics, 1)
+    loss_metric = loss_container.metrics[0]
+    self.assertEqual(loss_metric.name, 'loss')
+    self.assertAlmostEqual(loss_metric.result().numpy(), .125)


 class MetricsContainerTest(keras_parameterized.TestCase):
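
The expected values in the loss tests above follow from the default SUM_OVER_BATCH_SIZE reduction: the mask zeroes out individual weights, but the divisor stays the full element count. A quick NumPy check of the masked, weighted case (illustrative only, not part of the diff):

import numpy as np

err = np.abs(np.array([[1., 1.], [0., 0.]]) - 1.)  # |y_p - y_t| = [[0, 0], [1, 1]]
sw = np.array([[.2, .3], [.5, 0.]]) * np.array([[1., 0.], [1., 0.]])  # sw * mask
print(np.sum(err * sw) / err.size)  # 0.125: sum over batch size (4 elements)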
@@ -566,6 +615,76 @@ class MetricsContainerTest(keras_parameterized.TestCase):
     self.assertEqual(mse_metric.name, 'output3_mse')
     self.assertEqual(mse_metric.result().numpy(), 4.)

+  def test_metrics_masking(self):
+    metrics_container = compile_utils.MetricsContainer(
+        metrics=['mae'], weighted_metrics=['mse'])
+    y_p = constant_op.constant([[[1], [1]], [[0], [0]]], dtype=dtypes.float32)
+    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
+    y_p._keras_mask = constant_op.constant([[1, 1], [0, 0]],
+                                           dtype=dtypes.float32)
+
+    metrics_container.update_state(y_t, y_p)
+    self.assertLen(metrics_container.metrics, 2)
+
+    mae_metric = metrics_container.metrics[0]
+    self.assertEqual(mae_metric.name, 'mae')
+    self.assertAlmostEqual(mae_metric.result().numpy(), 0)
+
+    weighted_mse_metric = metrics_container.metrics[1]
+    self.assertEqual(weighted_mse_metric.name, 'mse')
+    self.assertAlmostEqual(weighted_mse_metric.result().numpy(), 0)
+
+  def test_metrics_sample_weight(self):
+    metrics_container = compile_utils.MetricsContainer(
+        metrics=['mae'], weighted_metrics=['mse'])
+    y_p = constant_op.constant([[[1], [1]], [[0], [1]]], dtype=dtypes.float32)
+    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
+    sw = constant_op.constant([[.2, .3], [.5, 0]], dtype=dtypes.float32)
+
+    metrics_container.update_state(y_t, y_p, sample_weight=sw)
+    self.assertLen(metrics_container.metrics, 2)
+
+    mae_metric = metrics_container.metrics[0]
+    self.assertEqual(mae_metric.name, 'mae')
+    self.assertAlmostEqual(mae_metric.result().numpy(), .25)  # 1 / 4
+
+    weighted_mse_metric = metrics_container.metrics[1]
+    self.assertEqual(weighted_mse_metric.name, 'mse')
+    self.assertAlmostEqual(weighted_mse_metric.result().numpy(), .5)  # .5 / 1
+
+  def test_metrics_masking_sample_weight(self):
+    metrics_container = compile_utils.MetricsContainer(
+        metrics=['mae'], weighted_metrics=['mse'])
+    y_p = constant_op.constant([[[1], [1]], [[0], [1]]], dtype=dtypes.float32)
+    y_t = constant_op.constant([[[1], [1]], [[1], [1]]], dtype=dtypes.float32)
+    sw = constant_op.constant([[.3, .2], [.2, .3]], dtype=dtypes.float32)
+    y_p._keras_mask = constant_op.constant([[1, 0], [1, 0]],
+                                           dtype=dtypes.float32)
+
+    metrics_container.update_state(y_t, y_p, sample_weight=sw)
+    self.assertLen(metrics_container.metrics, 2)
+
+    mae_metric = metrics_container.metrics[0]
+    self.assertEqual(mae_metric.name, 'mae')
+    self.assertAlmostEqual(mae_metric.result().numpy(), .5)  # 1 / 2
+
+    weighted_mse_metric = metrics_container.metrics[1]
+    self.assertEqual(weighted_mse_metric.name, 'mse')
+    self.assertAlmostEqual(weighted_mse_metric.result().numpy(), .2 / .5)
+
+  def test_loss_class_as_metric_with_distribution(self):
+    distribution = one_device_strategy.OneDeviceStrategy('/device:CPU:0')
+    with distribution.scope():
+      metric_container = compile_utils.MetricsContainer(
+          losses_mod.MeanSquaredError())
+      y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
+      metric_container.update_state(y_t, y_p)
+
+      self.assertLen(metric_container.metrics, 1)
+      metric = metric_container.metrics[0]
+      self.assertEqual(metric.name, 'mean_squared_error')
+      self.assertEqual(metric.result().numpy(), 1.)

 if __name__ == '__main__':
   ops.enable_eager_execution()