Factoring out a create_loss() function for all Heads.

PiperOrigin-RevId: 164989700
Author: A. Unique TensorFlower (2017-08-11 09:55:10 -07:00)
Committed by: TensorFlower Gardener
parent 3c482c66b5
commit 8d75705ffa
2 changed files with 571 additions and 121 deletions


@@ -20,6 +20,7 @@ from __future__ import print_function
 import abc
+import collections
 import six
 
 from tensorflow.python.estimator import model_fn
@@ -46,6 +47,10 @@ from tensorflow.python.summary import summary
 _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+
+LossAndLabels = collections.namedtuple('LossAndLabels',
+                                       ['unweighted_loss', 'processed_labels'])
+
 
 class _Head(object):
   """Interface for the head/top of a model.
@@ -114,6 +119,28 @@ class _Head(object):
     """
     raise NotImplementedError('Calling an abstract method.')
 
+  @abc.abstractmethod
+  def create_loss(self, features, mode, logits, labels):
+    """Returns a loss Tensor from provided logits.
+
+    This function is designed to be used by framework developers. Almost all
+    users should use create_estimator_spec(), which calls this internally.
+    `mode` and `features` are most likely not used, but some Head
+    implementations may require them.
+
+    Args:
+      features: Input `dict` of `Tensor` objects.
+      mode: Estimator's `ModeKeys`.
+      logits: logits `Tensor` to be used for loss construction.
+      labels: Labels `Tensor`.
+
+    Returns:
+      A LossAndLabels that contains the `Tensor` representing the loss and
+      possibly processed labels (e.g. vocabulary lookup, shape manipulation,
+      etc.), to be extendable in the future.
+    """
+    raise NotImplementedError('Calling an abstract method.')
+
   @abc.abstractmethod
   def create_estimator_spec(
       self, features, mode, logits, labels=None, train_op_fn=None):
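To make the new contract concrete, here is a hedged sketch of a hypothetical `_Head` subclass implementing only `create_loss()`. `_MyL1Head` and its L1 loss are invented for illustration; only the interface mirrors the diff, and the TF 1.x `tf.losses` API is assumed:

```python
import tensorflow as tf

class _MyL1Head(_Head):  # _Head as defined in this file.
  """Hypothetical head using an L1 loss; illustration only."""

  def create_loss(self, features, mode, logits, labels):
    """See `_Head`."""
    del mode, features  # Unused for this head.
    labels = tf.to_float(labels)
    return LossAndLabels(
        unweighted_loss=tf.losses.absolute_difference(
            labels=labels, predictions=logits,
            reduction=tf.losses.Reduction.NONE),
        processed_labels=labels)

  # Other abstract methods (logits_dimension, create_estimator_spec) are
  # omitted for brevity; a real subclass must implement them too.
```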
@@ -363,6 +390,17 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
           name='class_id_lookup').lookup(labels)
     return _assert_range(label_ids, self._n_classes)
 
+  def create_loss(self, features, mode, logits, labels):
+    """See `Head`."""
+    del mode, features  # Unused for this head.
+    label_ids = self._label_ids(_check_labels(_maybe_expand_dim(labels), 1))
+    unweighted_loss = losses.sparse_softmax_cross_entropy(
+        labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
+    # Restore the squeezed dim, so unweighted_loss matches the weights shape.
+    return LossAndLabels(
+        unweighted_loss=array_ops.expand_dims(unweighted_loss, axis=(1,)),
+        processed_labels=label_ids)
+
   def create_estimator_spec(
       self, features, mode, logits, labels=None, train_op_fn=None):
     """See `Head`."""
@@ -412,12 +450,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
           })
 
     # Eval.
-    label_ids = self._label_ids(_check_labels(_maybe_expand_dim(labels), 1))
-    unweighted_loss = losses.sparse_softmax_cross_entropy(
-        labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
-    # Restore the squeezed dim, so unweighted_loss matches the weights shape.
-    unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=(1,))
+    unweighted_loss, label_ids = self.create_loss(
+        features=features, mode=mode, logits=logits, labels=labels)
     weights = _weights(features, self._weight_column)
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
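The hunk above shows the pattern every head now shares: `create_loss()` yields per-example losses, and `create_estimator_spec()` folds in the weights and reduces to a scalar. Sketched with made-up tensors (TF 1.x assumed):

```python
import tensorflow as tf

unweighted_loss = tf.constant([[1.2], [0.4], [0.9]])  # [batch_size, 1]
weights = tf.constant([[1.0], [2.0], [0.5]])          # from _weights() in the real code
training_loss = tf.losses.compute_weighted_loss(
    unweighted_loss, weights=weights, reduction=tf.losses.Reduction.SUM)
```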
@@ -573,6 +607,21 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
           threshold=threshold, name=recall_key)
     return metric_ops
 
+  def create_loss(self, features, mode, logits, labels):
+    """See `Head`."""
+    del mode, features  # Unused for this head.
+    labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension)
+    if self._label_vocabulary is not None:
+      labels = lookup_ops.index_table_from_tensor(
+          vocabulary_list=tuple(self._label_vocabulary),
+          name='class_id_lookup').lookup(labels)
+    labels = math_ops.to_float(labels)
+    labels = _assert_range(labels, 2)
+    return LossAndLabels(
+        unweighted_loss=nn.sigmoid_cross_entropy_with_logits(
+            labels=labels, logits=logits),
+        processed_labels=labels)
+
   def create_estimator_spec(
       self, features, mode, logits, labels=None, train_op_fn=None):
     """See `Head`."""
@@ -624,15 +673,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
           })
 
     # Eval.
-    labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension)
-    if self._label_vocabulary is not None:
-      labels = lookup_ops.index_table_from_tensor(
-          vocabulary_list=tuple(self._label_vocabulary),
-          name='class_id_lookup').lookup(labels)
-    labels = math_ops.to_float(labels)
-    labels = _assert_range(labels, 2)
-    unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
-        labels=labels, logits=logits, name='loss')
+    unweighted_loss, processed_labels = self.create_loss(
+        features=features, mode=mode, logits=logits, labels=labels)
     weights = _weights(features, self._weight_column)
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
@@ -642,7 +684,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
           predictions=predictions,
           loss=training_loss,
           eval_metric_ops=self._eval_metric_ops(
-              labels=labels,
+              labels=processed_labels,
               logits=logits,
               logistic=logistic,
               scores=scores,
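Note the second half of the return value being put to use here: the eval metrics now receive the `processed_labels` (post vocabulary lookup and float cast) rather than the raw input labels. A sketch of the consumer side, assuming a `head` instance and in-scope `features`, `logits`, and `labels`:

```python
import tensorflow as tf

# create_loss() is now the single source of truth for label processing.
unweighted_loss, processed_labels = head.create_loss(
    features=features, mode=tf.estimator.ModeKeys.EVAL,
    logits=logits, labels=labels)
# The metrics then see the processed labels, e.g.:
#   self._eval_metric_ops(labels=processed_labels, logits=logits, ...)
```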
@@ -701,6 +743,16 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
   def logits_dimension(self):
     return self._logits_dimension
 
+  def create_loss(self, features, mode, logits, labels):
+    """See `Head`."""
+    del mode, features  # Unused for this head.
+    labels = _check_labels(
+        _maybe_expand_dim(math_ops.to_float(labels)), self._logits_dimension)
+    return LossAndLabels(
+        unweighted_loss=losses.mean_squared_error(
+            labels=labels, predictions=logits, reduction=losses.Reduction.NONE),
+        processed_labels=labels)
+
   def create_estimator_spec(
       self, features, mode, logits, labels=None, train_op_fn=None):
     """See `Head`."""
@@ -715,10 +767,8 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
           export_outputs={'': export_output.RegressionOutput(value=logits)})
 
     # Eval.
-    labels = _check_labels(_maybe_expand_dim(math_ops.to_float(labels)),
-                           self._logits_dimension)
-    unweighted_loss = losses.mean_squared_error(
-        labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
+    unweighted_loss, _ = self.create_loss(
+        features=features, mode=mode, logits=logits, labels=labels)
     weights = _weights(features, self._weight_column)
     training_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
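Taken together, the refactor gives framework developers a way to get at the per-example loss without building a full `EstimatorSpec`, as the new docstring promises. An end-to-end sketch, assuming `head` is an instance of one of the `_Head` subclasses above and TF 1.x:

```python
import tensorflow as tf

features = {'x': tf.constant([[1.0], [2.0]])}  # illustrative inputs
logits = tf.constant([[0.1], [0.7]])
labels = tf.constant([[0.0], [1.0]])

unweighted_loss, processed_labels = head.create_loss(
    features=features, mode=tf.estimator.ModeKeys.TRAIN,
    logits=logits, labels=labels)
training_loss = tf.losses.compute_weighted_loss(
    unweighted_loss, reduction=tf.losses.Reduction.SUM)
```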

(Diff of the second changed file suppressed because it is too large.)