Accept any callable as a custom Loss or custom Metric.
Previously, passing a callable class instance raised an error related to not being able to get the name for the custom loss / metric. PiperOrigin-RevId: 338846484 Change-Id: I5ffb5aaf2e19a31615cd3bb43073dc29ee4cfc33
This commit is contained in:
parent
d3d44d02c8
commit
a48834e930
@ -24,6 +24,7 @@ import six
|
|||||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||||
from tensorflow.python.keras import losses as losses_mod
|
from tensorflow.python.keras import losses as losses_mod
|
||||||
from tensorflow.python.keras import metrics as metrics_mod
|
from tensorflow.python.keras import metrics as metrics_mod
|
||||||
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.keras.utils import losses_utils
|
from tensorflow.python.keras.utils import losses_utils
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -261,7 +262,9 @@ class LossesContainer(Container):
|
|||||||
|
|
||||||
loss = losses_mod.get(loss)
|
loss = losses_mod.get(loss)
|
||||||
if not isinstance(loss, losses_mod.Loss):
|
if not isinstance(loss, losses_mod.Loss):
|
||||||
loss_name = loss.__name__
|
loss_name = get_custom_object_name(loss)
|
||||||
|
if loss_name is None:
|
||||||
|
raise ValueError('Loss should be a callable, found: {}'.format(loss))
|
||||||
loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
|
loss = losses_mod.LossFunctionWrapper(loss, name=loss_name)
|
||||||
loss._allow_sum_over_batch_size = True # pylint: disable=protected-access
|
loss._allow_sum_over_batch_size = True # pylint: disable=protected-access
|
||||||
return loss
|
return loss
|
||||||
@ -466,11 +469,11 @@ class MetricsContainer(Container):
|
|||||||
if not isinstance(metric_obj, metrics_mod.Metric):
|
if not isinstance(metric_obj, metrics_mod.Metric):
|
||||||
if isinstance(metric, six.string_types):
|
if isinstance(metric, six.string_types):
|
||||||
metric_name = metric
|
metric_name = metric
|
||||||
elif hasattr(metric, 'name'):
|
|
||||||
metric_name = metric.name # TODO(omalleyt): Is this needed?
|
|
||||||
else:
|
else:
|
||||||
# function was passed.
|
metric_name = get_custom_object_name(metric)
|
||||||
metric_name = metric.__name__
|
if metric_name is None:
|
||||||
|
raise ValueError(
|
||||||
|
'Metric should be a callable, found: {}'.format(metric))
|
||||||
|
|
||||||
metric_obj = metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)
|
metric_obj = metrics_mod.MeanMetricWrapper(metric_obj, name=metric_name)
|
||||||
|
|
||||||
@ -638,3 +641,22 @@ def apply_mask(y_p, sw, mask):
|
|||||||
else:
|
else:
|
||||||
sw = mask
|
sw = mask
|
||||||
return sw
|
return sw
|
||||||
|
|
||||||
|
|
||||||
|
def get_custom_object_name(obj):
|
||||||
|
"""Returns the name to use for a custom loss or metric callable.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
obj: Custom loss of metric callable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Name to use, or `None` if the object was not recognized.
|
||||||
|
"""
|
||||||
|
if hasattr(obj, 'name'): # Accept `Loss` instance as `Metric`.
|
||||||
|
return obj.name
|
||||||
|
elif hasattr(obj, '__name__'): # Function.
|
||||||
|
return obj.__name__
|
||||||
|
elif hasattr(obj, '__class__'): # Class instance.
|
||||||
|
return generic_utils.to_snake_case(obj.__class__.__name__)
|
||||||
|
else: # Unrecognized object.
|
||||||
|
return None
|
||||||
|
|||||||
@ -338,6 +338,24 @@ class LossesContainerTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(loss_metric.name, 'loss')
|
self.assertEqual(loss_metric.name, 'loss')
|
||||||
self.assertAlmostEqual(loss_metric.result().numpy(), .125)
|
self.assertAlmostEqual(loss_metric.result().numpy(), .125)
|
||||||
|
|
||||||
|
def test_custom_loss_callables(self):
|
||||||
|
|
||||||
|
def custom_loss_fn(y_true, y_pred):
|
||||||
|
return math_ops.reduce_sum(y_true - y_pred)
|
||||||
|
|
||||||
|
class CustomLossClass(object):
|
||||||
|
|
||||||
|
def __call__(self, y_true, y_pred):
|
||||||
|
return math_ops.reduce_sum(y_true - y_pred)
|
||||||
|
|
||||||
|
loss_container = compile_utils.LossesContainer(
|
||||||
|
[custom_loss_fn, CustomLossClass()])
|
||||||
|
y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
|
||||||
|
loss_container(y_t, y_p)
|
||||||
|
|
||||||
|
self.assertEqual(loss_container._losses[0].name, 'custom_loss_fn')
|
||||||
|
self.assertEqual(loss_container._losses[1].name, 'custom_loss_class')
|
||||||
|
|
||||||
|
|
||||||
class MetricsContainerTest(keras_parameterized.TestCase):
|
class MetricsContainerTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
@ -685,6 +703,24 @@ class MetricsContainerTest(keras_parameterized.TestCase):
|
|||||||
self.assertEqual(metric.name, 'mean_squared_error')
|
self.assertEqual(metric.name, 'mean_squared_error')
|
||||||
self.assertEqual(metric.result().numpy(), 1.)
|
self.assertEqual(metric.result().numpy(), 1.)
|
||||||
|
|
||||||
|
def test_custom_metric_callables(self):
|
||||||
|
|
||||||
|
def custom_metric_fn(y_true, y_pred):
|
||||||
|
return math_ops.reduce_sum(y_true - y_pred)
|
||||||
|
|
||||||
|
class CustomMetricClass(object):
|
||||||
|
|
||||||
|
def __call__(self, y_true, y_pred):
|
||||||
|
return math_ops.reduce_sum(y_true - y_pred)
|
||||||
|
|
||||||
|
metric_container = compile_utils.MetricsContainer(
|
||||||
|
[custom_metric_fn, CustomMetricClass()])
|
||||||
|
y_t, y_p = array_ops.ones((10, 5)), array_ops.zeros((10, 5))
|
||||||
|
metric_container.update_state(y_t, y_p)
|
||||||
|
|
||||||
|
self.assertEqual(metric_container.metrics[0].name, 'custom_metric_fn')
|
||||||
|
self.assertEqual(metric_container.metrics[1].name, 'custom_metric_class')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user