Accept any callable as a custom Loss or custom Metric.
Previously, passing a callable class instance raised an error related to not being able to get the name for the custom loss / metric. PiperOrigin-RevId: 338846484 Change-Id: I5ffb5aaf2e19a31615cd3bb43073dc29ee4cfc33
This commit is contained in:
parent
d3d44d02c8
commit
a48834e930
@ -24,6 +24,7 @@ import six
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.keras import losses as losses_mod
|
||||
from tensorflow.python.keras import metrics as metrics_mod
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import losses_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -261,7 +262,9 @@ class LossesContainer(Container):
|
||||
|
||||
loss = losses_mod.get(loss)
|
||||
if not isinstance(loss, losses_mod.Loss):
|
||||
loss_name = loss.__name__
|
||||
loss_name = get_custom_object_name(loss)
|
||||
if loss_name is None:
|
||||
raise ValueError('Loss should be a callable, found: {}'.format(loss))
|
||||
loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
|
||||
loss._allow_sum_over_batch_size = True # pylint: disable=protected-access
|
||||
return loss
|
||||
@ -466,11 +469,11 @@ class MetricsContainer(Container):
|
||||
if not isinstance(metric_obj, metrics_mod.Metric):
|
||||
if isinstance(metric, six.string_types):
|
||||
metric_name = metric
|
||||
elif hasattr(metric, 'name'):
|
||||
metric_name = metric.name # TODO(omalleyt): Is this needed?
|
||||
else:
|
||||
# function was passed.
|
||||
metric_name = metric.__name__
|
||||
metric_name = get_custom_object_name(metric)
|
||||
if metric_name is None:
|
||||
raise ValueError(
|
||||
'Metric should be a callable, found: {}'.format(metric))
|
||||
|
||||
metric_obj = metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)
|
||||
|
||||
@ -638,3 +641,22 @@ def apply_mask(y_p, sw, mask):
|
||||
else:
|
||||
sw = mask
|
||||
return sw
|
||||
|
||||
|
||||
def get_custom_object_name(obj):
|
||||
"""Returns the name to use for a custom loss or metric callable.
|
||||
|
||||
Arguments:
|
||||
obj: Custom loss of metric callable
|
||||
|
||||
Returns:
|
||||
Name to use, or `None` if the object was not recognized.
|
||||
"""
|
||||
if hasattr(obj, 'name'): # Accept `Loss` instance as `Metric`.
|
||||
return obj.name
|
||||
elif hasattr(obj, '__name__'): # Function.
|
||||
return obj.__name__
|
||||
elif hasattr(obj, '__class__'): # Class instance.
|
||||
return generic_utils.to_snake_case(obj.__class__.__name__)
|
||||
else: # Unrecognized object.
|
||||
return None
|
||||
|
||||
@ -338,6 +338,24 @@ class LossesContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(loss_metric.name, 'loss')
|
||||
self.assertAlmostEqual(loss_metric.result().numpy(), .125)
|
||||
|
||||
def test_custom_loss_callables(self):
|
||||
|
||||
def custom_loss_fn(y_true, y_pred):
|
||||
return math_ops.reduce_sum(y_true - y_pred)
|
||||
|
||||
class CustomLossClass(object):
|
||||
|
||||
def __call__(self, y_true, y_pred):
|
||||
return math_ops.reduce_sum(y_true - y_pred)
|
||||
|
||||
loss_container = compile_utils.LossesContainer(
|
||||
[custom_loss_fn, CustomLossClass()])
|
||||
y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
|
||||
loss_container(y_t, y_p)
|
||||
|
||||
self.assertEqual(loss_container._losses[0].name, 'custom_loss_fn')
|
||||
self.assertEqual(loss_container._losses[1].name, 'custom_loss_class')
|
||||
|
||||
|
||||
class MetricsContainerTest(keras_parameterized.TestCase):
|
||||
|
||||
@ -685,6 +703,24 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(metric.name, 'mean_squared_error')
|
||||
self.assertEqual(metric.result().numpy(), 1.)
|
||||
|
||||
def test_custom_metric_callables(self):
|
||||
|
||||
def custom_metric_fn(y_true, y_pred):
|
||||
return math_ops.reduce_sum(y_true - y_pred)
|
||||
|
||||
class CustomMetricClass(object):
|
||||
|
||||
def __call__(self, y_true, y_pred):
|
||||
return math_ops.reduce_sum(y_true - y_pred)
|
||||
|
||||
metric_container = compile_utils.MetricsContainer(
|
||||
[custom_metric_fn, CustomMetricClass()])
|
||||
y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
|
||||
metric_container.update_state(y_t, y_p)
|
||||
|
||||
self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn')
|
||||
self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user