Accept any callable as a custom Loss or custom Metric.

Previously, passing a callable class instance raised an error because the name
for the custom loss / metric could not be determined.

PiperOrigin-RevId: 338846484
Change-Id: I5ffb5aaf2e19a31615cd3bb43073dc29ee4cfc33
Thomas O'Malley 2020-10-24 11:44:23 -07:00 committed by TensorFlower Gardener
parent d3d44d02c8
commit a48834e930
2 changed files with 63 additions and 5 deletions
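
For context, a minimal usage sketch of what this change enables (illustrative only, not part of the commit; the model and layer below are placeholders) — a plain function and a callable class instance passed directly to `compile`:

    import tensorflow as tf

    def custom_loss_fn(y_true, y_pred):
      # Any callable taking (y_true, y_pred) can now be used as a loss.
      return tf.reduce_mean(tf.square(y_true - y_pred))

    class CustomMetric(object):
      """A callable class instance; previously this failed name resolution."""

      def __call__(self, y_true, y_pred):
        return tf.reduce_mean(tf.abs(y_true - y_pred))

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    # The metric is reported under the snake_cased class name, e.g. 'custom_metric'.
    model.compile(optimizer='sgd', loss=custom_loss_fn, metrics=[CustomMetric()])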


@@ -24,6 +24,7 @@ import six
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.keras import losses as losses_mod
 from tensorflow.python.keras import metrics as metrics_mod
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import losses_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
@@ -261,7 +262,9 @@ class LossesContainer(Container):
     loss = losses_mod.get(loss)
     if not isinstance(loss, losses_mod.Loss):
-      loss_name = loss.__name__
+      loss_name = get_custom_object_name(loss)
+      if loss_name is None:
+        raise ValueError('Loss should be a callable, found: {}'.format(loss))
       loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
     loss._allow_sum_over_batch_size = True  # pylint: disable=protected-access
     return loss
@@ -466,11 +469,11 @@ class MetricsContainer(Container):
     if not isinstance(metric_obj, metrics_mod.Metric):
       if isinstance(metric, six.string_types):
         metric_name = metric
-      elif hasattr(metric, 'name'):
-        metric_name = metric.name  # TODO(omalleyt): Is this needed?
       else:
-        # function was passed.
-        metric_name = metric.__name__
+        metric_name = get_custom_object_name(metric)
+        if metric_name is None:
+          raise ValueError(
+              'Metric should be a callable, found: {}'.format(metric))
 
       metric_obj = metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)
@@ -638,3 +641,22 @@ def apply_mask(y_p, sw, mask):
   else:
     sw = mask
   return sw
+
+
+def get_custom_object_name(obj):
+  """Returns the name to use for a custom loss or metric callable.
+
+  Arguments:
+    obj: Custom loss or metric callable
+
+  Returns:
+    Name to use, or `None` if the object was not recognized.
+  """
+  if hasattr(obj, 'name'):  # Accept `Loss` instance as `Metric`.
+    return obj.name
+  elif hasattr(obj, '__name__'):  # Function.
+    return obj.__name__
+  elif hasattr(obj, '__class__'):  # Class instance.
+    return generic_utils.to_snake_case(obj.__class__.__name__)
+  else:  # Unrecognized object.
+    return None
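
A standalone sketch of the lookup order this helper implements (not part of the diff; the import path matches what the tests use) — a `name` attribute wins, then a function's `__name__`, then the snake_cased class name of a callable instance:

    from tensorflow.python.keras.engine import compile_utils

    class NamedLoss(object):
      name = 'weighted_error'  # picked up via the `name` attribute

      def __call__(self, y_true, y_pred):
        return y_true - y_pred

    def plain_fn(y_true, y_pred):  # picked up via `__name__`
      return y_true - y_pred

    class CallableLoss(object):  # picked up via the snake_cased class name

      def __call__(self, y_true, y_pred):
        return y_true - y_pred

    print(compile_utils.get_custom_object_name(NamedLoss()))     # 'weighted_error'
    print(compile_utils.get_custom_object_name(plain_fn))        # 'plain_fn'
    print(compile_utils.get_custom_object_name(CallableLoss()))  # 'callable_loss'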


@@ -338,6 +338,24 @@ class LossesContainerTest(keras_parameterized.TestCase):
     self.assertEqual(loss_metric.name, 'loss')
     self.assertAlmostEqual(loss_metric.result().numpy(), .125)
 
+  def test_custom_loss_callables(self):
+
+    def custom_loss_fn(y_true, y_pred):
+      return math_ops.reduce_sum(y_true - y_pred)
+
+    class CustomLossClass(object):
+
+      def __call__(self, y_true, y_pred):
+        return math_ops.reduce_sum(y_true - y_pred)
+
+    loss_container = compile_utils.LossesContainer(
+        [custom_loss_fn, CustomLossClass()])
+    y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
+    loss_container(y_t, y_p)
+
+    self.assertEqual(loss_container._losses[0].name, 'custom_loss_fn')
+    self.assertEqual(loss_container._losses[1].name, 'custom_loss_class')
+
 
 class MetricsContainerTest(keras_parameterized.TestCase):
@@ -685,6 +703,24 @@ class MetricsContainerTest(keras_parameterized.TestCase):
     self.assertEqual(metric.name, 'mean_squared_error')
     self.assertEqual(metric.result().numpy(), 1.)
 
+  def test_custom_metric_callables(self):
+
+    def custom_metric_fn(y_true, y_pred):
+      return math_ops.reduce_sum(y_true - y_pred)
+
+    class CustomMetricClass(object):
+
+      def __call__(self, y_true, y_pred):
+        return math_ops.reduce_sum(y_true - y_pred)
+
+    metric_container = compile_utils.MetricsContainer(
+        [custom_metric_fn, CustomMetricClass()])
+    y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
+    metric_container.update_state(y_t, y_p)
+
+    self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn')
+    self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class')
if __name__ == '__main__':
  ops.enable_eager_execution()