Add logcosh v2 loss and metric.

PiperOrigin-RevId: 227894520
This commit is contained in:
Pavithra Vijay 2019-01-04 12:17:54 -08:00 committed by TensorFlower Gardener
parent ee5ca473bf
commit f1d4d18a62
4 changed files with 182 additions and 0 deletions

View File

@ -540,6 +540,33 @@ class LogLoss(Loss):
return logloss(y_true, y_pred, epsilon=self.epsilon)
class Logcosh(Loss):
"""Computes the logarithm of the hyperbolic cosine of the prediction error.
logcosh = log((exp(x) + exp(-x))/2) where x is the error `y_pred` - `y_true`.
Usage:
```python
l = tf.losses.Logcosh()
loss = l([0., 1., 1.], [1., 0., 1.])
print('Loss: ', loss.numpy()) # Loss: 0.289
```
Usage with tf.keras API:
```python
model = keras.models.Model(inputs, outputs)
model.compile('sgd', loss=tf.losses.Logcosh())
```
"""
def call(self, y_true, y_pred):
y_pred = ops.convert_to_tensor(y_pred)
y_true = math_ops.cast(y_true, y_pred.dtype)
return logcosh(y_true, y_pred)
@keras_export('keras.metrics.mean_squared_error',
'keras.metrics.mse',
'keras.metrics.MSE',

View File

@ -1094,5 +1094,86 @@ class LogLossTest(test.TestCase):
self.assertAlmostEqual(self.evaluate(loss), 0., 3)
@test_util.run_all_in_graph_and_eager_modes
class LogcoshTest(test.TestCase):
def setup(self):
y_pred = np.asarray([1, 9, 2, -5, -2, 6]).reshape((2, 3))
y_true = np.asarray([4, 8, 12, 8, 1, 3]).reshape((2, 3))
self.batch_size = 6
error = y_pred - y_true
self.expected_losses = np.log((np.exp(error) + np.exp(-error)) / 2)
self.y_pred = constant_op.constant(y_pred, dtype=dtypes.float32)
self.y_true = constant_op.constant(y_true)
def test_config(self):
logcosh_obj = keras.losses.Logcosh(
reduction=losses_impl.ReductionV2.SUM, name='logcosh_loss')
self.assertEqual(logcosh_obj.name, 'logcosh_loss')
self.assertEqual(logcosh_obj.reduction, losses_impl.ReductionV2.SUM)
def test_unweighted(self):
self.setup()
logcosh_obj = keras.losses.Logcosh()
loss = logcosh_obj(self.y_true, self.y_pred)
expected_loss = np.sum(self.expected_losses) / self.batch_size
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
def test_scalar_weighted(self):
self.setup()
logcosh_obj = keras.losses.Logcosh()
sample_weight = 2.3
loss = logcosh_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
expected_loss = sample_weight * np.sum(
self.expected_losses) / self.batch_size
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
# Verify we get the same output when the same input is given
loss_2 = logcosh_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
self.assertAlmostEqual(self.evaluate(loss), self.evaluate(loss_2), 3)
def test_sample_weighted(self):
self.setup()
logcosh_obj = keras.losses.Logcosh()
sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
loss = logcosh_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
expected_loss = np.multiply(
self.expected_losses,
np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
expected_loss = np.sum(expected_loss) / self.batch_size
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
def test_timestep_weighted(self):
self.setup()
logcosh_obj = keras.losses.Logcosh()
y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)
y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)
error = y_pred - y_true
expected_losses = np.log((np.exp(error) + np.exp(-error)) / 2)
sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3, 1))
y_pred = constant_op.constant(y_pred, dtype=dtypes.float32)
y_true = constant_op.constant(y_true)
loss = logcosh_obj(
y_true,
y_pred,
sample_weight=constant_op.constant(sample_weight, shape=(2, 3)))
expected_loss = np.sum(expected_losses * sample_weight) / self.batch_size
self.assertAlmostEqual(self.evaluate(loss), expected_loss, 3)
def test_zero_weighted(self):
self.setup()
logcosh_obj = keras.losses.Logcosh()
sample_weight = 0
loss = logcosh_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
self.assertAlmostEqual(self.evaluate(loss), 0., 3)
if __name__ == '__main__':
test.main()

View File

@ -1654,6 +1654,37 @@ class RootMeanSquaredError(Mean):
return math_ops.sqrt(math_ops.div_no_nan(self.total, self.count))
class Logcosh(MeanMetricWrapper):
"""Computes the logarithm of the hyperbolic cosine of the prediction error.
logcosh = log((exp(x) + exp(-x))/2) where x is the error `y_pred` - `y_true`.
Usage:
```python
m = tf.keras.metrics.Logcosh()
m.update_state([0., 1., 1.], [1., 0., 1.])
print('Final result: ', m.result().numpy()) # Final result: 0.289
```
Usage with tf.keras API:
```python
model = keras.models.Model(inputs, outputs)
model.compile('sgd', metrics=[tf.keras.metrics.Logcosh()])
```
"""
def __init__(self, name='logcosh', dtype=None):
super(Logcosh, self).__init__(logcosh, name, dtype=dtype)
@classmethod
def from_config(cls, config):
if 'fn' in config:
config.pop('fn')
return super(Logcosh, cls).from_config(config)
def accuracy(y_true, y_pred):
y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
if y_true.dtype != y_pred.dtype:

View File

@ -1467,6 +1467,49 @@ class SparseTopKCategoricalAccuracyTest(test.TestCase):
self.assertEqual(0.5, self.evaluate(result)) # only 1 sample matches.
@test_util.run_all_in_graph_and_eager_modes
class LogcoshTest(test.TestCase):
def setup(self):
y_pred = np.asarray([1, 9, 2, -5, -2, 6]).reshape((2, 3))
y_true = np.asarray([4, 8, 12, 8, 1, 3]).reshape((2, 3))
self.batch_size = 6
error = y_pred - y_true
self.expected_results = np.log((np.exp(error) + np.exp(-error)) / 2)
self.y_pred = constant_op.constant(y_pred, dtype=dtypes.float32)
self.y_true = constant_op.constant(y_true)
def test_config(self):
logcosh_obj = metrics.Logcosh(name='logcosh', dtype=dtypes.int32)
self.assertEqual(logcosh_obj.name, 'logcosh')
self.assertEqual(logcosh_obj._dtype, dtypes.int32)
def test_unweighted(self):
self.setup()
logcosh_obj = metrics.Logcosh()
self.evaluate(variables.variables_initializer(logcosh_obj.variables))
update_op = logcosh_obj.update_state(self.y_true, self.y_pred)
self.evaluate(update_op)
result = logcosh_obj.result()
expected_result = np.sum(self.expected_results) / self.batch_size
self.assertAllClose(result, expected_result, atol=1e-3)
def test_weighted(self):
self.setup()
logcosh_obj = metrics.Logcosh()
self.evaluate(variables.variables_initializer(logcosh_obj.variables))
sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
result = logcosh_obj(self.y_true, self.y_pred, sample_weight=sample_weight)
sample_weight = np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3))
expected_result = np.multiply(self.expected_results, sample_weight)
expected_result = np.sum(expected_result) / np.sum(sample_weight)
self.assertAllClose(self.evaluate(result), expected_result, atol=1e-3)
def _get_model(compile_metrics):
model_layers = [
layers.Dense(3, activation='relu', kernel_initializer='ones'),