Factoring out a create_loss() function for all Heads.
PiperOrigin-RevId: 164989700
parent 3c482c66b5
commit 8d75705ffa
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import abc
+import collections
 
 import six
 
 from tensorflow.python.estimator import model_fn
@@ -46,6 +47,10 @@ from tensorflow.python.summary import summary
 
 _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
 
 
+LossAndLabels = collections.namedtuple('LossAndLabels',
+                                       ['unweighted_loss', 'processed_labels'])
+
+
 class _Head(object):
   """Interface for the head/top of a model.
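
For readers skimming the diff: LossAndLabels is a plain collections.namedtuple, so it unpacks like a tuple while leaving room to add fields later. A minimal stdlib-only sketch (the placeholder values are hypothetical):

    import collections

    LossAndLabels = collections.namedtuple('LossAndLabels',
                                           ['unweighted_loss', 'processed_labels'])

    # Constructed with keyword arguments, as the heads below do.
    result = LossAndLabels(unweighted_loss='loss_tensor',
                           processed_labels='labels_tensor')
    # Tuple-style unpacking works (the create_estimator_spec changes below
    # rely on it)...
    unweighted_loss, processed_labels = result
    # ...and so does named access, which keeps working if fields are added.
    assert result.unweighted_loss is unweighted_loss
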
@@ -114,6 +119,28 @@ class _Head(object):
     """
     raise NotImplementedError('Calling an abstract method.')
 
+  @abc.abstractmethod
+  def create_loss(self, features, mode, logits, labels):
+    """Returns a loss Tensor from provided logits.
+
+    This function is designed to be used by framework developers. Almost all
+    users should use create_estimator_spec(), which calls this internally.
+    `mode` and `features` are most likely not used, but some Head
+    implementations may require them.
+
+    Args:
+      features: Input `dict` of `Tensor` objects.
+      mode: Estimator's `ModeKeys`.
+      logits: logits `Tensor` to be used for loss construction.
+      labels: Labels `Tensor`.
+
+    Returns:
+      A LossAndLabels that contains the `Tensor` representing the loss and
+      possibly processed labels (e.g. vocabulary lookup, shape manipulation,
+      etc.), to be extendable in the future.
+    """
+    raise NotImplementedError('Calling an abstract method.')
+
   @abc.abstractmethod
   def create_estimator_spec(
       self, features, mode, logits, labels=None, train_op_fn=None):
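
A hedged sketch of the call pattern this abstract method enables for framework developers; `head`, `features`, `mode`, `logits`, and `labels` are assumed to exist in a surrounding model_fn and are not part of this commit:

    import tensorflow as tf

    # Sketch only: `head` is any concrete _Head; create_loss returns the
    # per-example losses plus labels after vocabulary lookup/shape handling.
    unweighted_loss, processed_labels = head.create_loss(
        features=features, mode=mode, logits=logits, labels=labels)
    # The heads below then weight and reduce the result like this:
    weights = tf.constant(1.0)  # placeholder for a real weight-column tensor
    training_loss = tf.losses.compute_weighted_loss(
        unweighted_loss, weights=weights, reduction=tf.losses.Reduction.SUM)
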
@@ -363,6 +390,17 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
           name='class_id_lookup').lookup(labels)
     return _assert_range(label_ids, self._n_classes)
 
+  def create_loss(self, features, mode, logits, labels):
+    """See `Head`."""
+    del mode, features  # Unused for this head.
+    label_ids = self._label_ids(_check_labels(_maybe_expand_dim(labels), 1))
+    unweighted_loss = losses.sparse_softmax_cross_entropy(
+        labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
+    # Restore the squeezed dim, so unweighted_loss matches the weights shape.
+    return LossAndLabels(
+        unweighted_loss=array_ops.expand_dims(unweighted_loss, axis=(1,)),
+        processed_labels=label_ids)
+
   def create_estimator_spec(
       self, features, mode, logits, labels=None, train_op_fn=None):
     """See `Head`."""
@@ -412,12 +450,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         })
 
     # Eval.
-    label_ids = self._label_ids(_check_labels(_maybe_expand_dim(labels), 1))
-
-    unweighted_loss = losses.sparse_softmax_cross_entropy(
-        labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
-    # Restore the squeezed dim, so unweighted_loss matches the weights shape.
-    unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=(1,))
+    unweighted_loss, label_ids = self.create_loss(
+        features=features, mode=mode, logits=logits, labels=labels)
     weights = _weights(features, self._weight_column)
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
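
Context for the expand_dims above: with Reduction.NONE, sparse_softmax_cross_entropy squeezes the trailing size-1 label dim and returns a rank-1 loss, while weights read from a weight column are rank-2. A self-contained toy illustration (shapes and values are made up):

    import tensorflow as tf

    logits = tf.constant([[1.0, 2.0, 0.5],
                          [0.1, 0.2, 3.0]])  # [batch=2, n_classes=3]
    label_ids = tf.constant([[1], [2]])      # [2, 1], like _check_labels output
    unweighted_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=label_ids, logits=logits, reduction=tf.losses.Reduction.NONE)
    # unweighted_loss has shape [2]; restore the squeezed dim to [2, 1] so it
    # lines up with [2, 1] example weights.
    unweighted_loss = tf.expand_dims(unweighted_loss, axis=1)
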
@@ -573,6 +607,21 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
           threshold=threshold, name=recall_key)
     return metric_ops
 
+  def create_loss(self, features, mode, logits, labels):
+    """See `Head`."""
+    del mode, features  # Unused for this head.
+    labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension)
+    if self._label_vocabulary is not None:
+      labels = lookup_ops.index_table_from_tensor(
+          vocabulary_list=tuple(self._label_vocabulary),
+          name='class_id_lookup').lookup(labels)
+    labels = math_ops.to_float(labels)
+    labels = _assert_range(labels, 2)
+    return LossAndLabels(
+        unweighted_loss=nn.sigmoid_cross_entropy_with_logits(
+            labels=labels, logits=logits),
+        processed_labels=labels)
+
   def create_estimator_spec(
       self, features, mode, logits, labels=None, train_op_fn=None):
     """See `Head`."""
@@ -624,15 +673,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
         })
 
     # Eval.
-    labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension)
-    if self._label_vocabulary is not None:
-      labels = lookup_ops.index_table_from_tensor(
-          vocabulary_list=tuple(self._label_vocabulary),
-          name='class_id_lookup').lookup(labels)
-    labels = math_ops.to_float(labels)
-    labels = _assert_range(labels, 2)
-    unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
-        labels=labels, logits=logits, name='loss')
+    unweighted_loss, processed_labels = self.create_loss(
+        features=features, mode=mode, logits=logits, labels=labels)
     weights = _weights(features, self._weight_column)
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
@@ -642,7 +684,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
         predictions=predictions,
         loss=training_loss,
         eval_metric_ops=self._eval_metric_ops(
-            labels=labels,
+            labels=processed_labels,
             logits=logits,
             logistic=logistic,
             scores=scores,
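
Why the metrics switch to processed_labels: for this head, create_loss may map vocabulary strings to {0, 1} floats, and the eval metrics must see those converted values rather than the raw labels. A small hypothetical run through the same lookup ops (the vocabulary and the session boilerplate are illustrative additions):

    import tensorflow as tf
    from tensorflow.python.ops import lookup_ops

    labels = tf.constant([['positive'], ['negative']])  # hypothetical strings
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=('negative', 'positive'), name='class_id_lookup')
    processed = tf.to_float(table.lookup(labels))

    with tf.Session() as sess:
      sess.run(tf.tables_initializer())
      print(sess.run(processed))  # [[1.], [0.]]
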
@@ -701,6 +743,16 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
   def logits_dimension(self):
     return self._logits_dimension
 
+  def create_loss(self, features, mode, logits, labels):
+    """See `Head`."""
+    del mode, features  # Unused for this head.
+    labels = _check_labels(
+        _maybe_expand_dim(math_ops.to_float(labels)), self._logits_dimension)
+    return LossAndLabels(
+        unweighted_loss=losses.mean_squared_error(
+            labels=labels, predictions=logits, reduction=losses.Reduction.NONE),
+        processed_labels=labels)
+
   def create_estimator_spec(
       self, features, mode, logits, labels=None, train_op_fn=None):
     """See `Head`."""
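
Unlike the classification heads, nothing needs reshaping here: with Reduction.NONE, mean_squared_error already returns one loss per label element, in the labels' own shape. An illustrative sketch with made-up values:

    import tensorflow as tf

    labels = tf.constant([[1.0], [3.0]])   # [batch=2, logits_dimension=1]
    logits = tf.constant([[1.5], [2.0]])   # predictions, same shape
    unweighted_loss = tf.losses.mean_squared_error(
        labels=labels, predictions=logits, reduction=tf.losses.Reduction.NONE)
    # unweighted_loss keeps shape [2, 1]: [[0.25], [1.0]]
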
@@ -715,10 +767,8 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
         export_outputs={'': export_output.RegressionOutput(value=logits)})
 
     # Eval.
-    labels = _check_labels(_maybe_expand_dim(math_ops.to_float(labels)),
-                           self._logits_dimension)
-    unweighted_loss = losses.mean_squared_error(
-        labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
+    unweighted_loss, _ = self.create_loss(
+        features=features, mode=mode, logits=logits, labels=labels)
     weights = _weights(features, self._weight_column)
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
File diff suppressed because it is too large