Add loss_only_head to hold additional loss terms for multi_head setup

PiperOrigin-RevId: 157875934
This commit is contained in:
A. Unique TensorFlower 2017-06-02 14:29:17 -07:00 committed by TensorFlower Gardener
parent 7cdcd0cca2
commit 9e25c68ad1
3 changed files with 116 additions and 3 deletions
tensorflow/contrib/learn/python/learn/estimators

View File

@ -308,6 +308,7 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_rea
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head
from tensorflow.contrib.learn.python.learn.estimators.head import Head
from tensorflow.contrib.learn.python.learn.estimators.head import loss_only_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_head
from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head

View File

@ -429,6 +429,23 @@ def multi_label_head(n_classes,
loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)
def loss_only_head(loss_fn, head_name=None):
"""Creates a Head that contains only loss terms.
Loss only head holds additional loss terms to be added to other heads and
usually represents additional regularization terms in the objective function.
Args:
loss_fn: a function that takes no argument and returns a list of
scalar tensors.
head_name: a name for for the head.
Returns:
An instance of `Head` to hold the additional losses.
"""
return _LossOnlyHead(loss_fn, head_name=head_name)
def multi_head(heads, loss_weights=None):
"""Creates a MultiHead stemming from same logits/hidden layer.
@ -1406,6 +1423,80 @@ class _MultiLabelHead(_SingleHead):
return metrics
class _LossOnlyHead(Head):
"""`Head` implementation for additional loss terms.
This class only holds loss terms unrelated to any other heads (labels),
e.g. regularization.
Common usage:
This is oftem combine with other heads in a multi head setup.
```python
head = multi_head([
head1, head2, loss_only_head('regularizer', regularizer)])
```
"""
def __init__(self, loss_fn, head_name=None):
self._loss_fn = loss_fn
self.head_name = head_name or "loss_only_head"
@property
def logits_dimension(self):
return 0
def create_model_fn_ops(self,
features,
mode,
labels=None,
train_op_fn=None,
logits=None,
logits_input=None,
scope=None):
"""See `_Head.create_model_fn_ops`.
Args:
features: Not been used.
mode: Estimator's `ModeKeys`.
labels: Labels `Tensor`, or `dict` of same.
train_op_fn: Function that takes a scalar loss and returns an op to
optimize with the loss.
logits: Not been used.
logits_input: Not been used.
scope: Optional scope for variable_scope. If provided, will be passed to
all heads. Most users will want to set this to `None`, so each head
constructs a separate variable_scope according to its `head_name`.
Returns:
A `ModelFnOps` object.
Raises:
ValueError: if `mode` is not recognition.
"""
_check_mode_valid(mode)
loss = None
train_op = None
if mode != model_fn.ModeKeys.INFER:
with variable_scope.variable_scope(scope, default_name=self.head_name):
loss = self._loss_fn()
if isinstance(loss, list):
loss = math_ops.add_n(loss)
logging_ops.scalar_summary(
_summary_key(self.head_name, mkey.LOSS), loss)
if mode == model_fn.ModeKeys.TRAIN:
if train_op_fn is None:
raise ValueError("train_op_fn can not be None in TRAIN mode")
with ops.name_scope(None, "train_op", (loss,)):
train_op = train_op_fn(loss)
return model_fn.ModelFnOps(
mode=mode,
loss=loss,
train_op=train_op,
predictions={},
eval_metric_ops={})
class _MultiHead(Head):
"""`Head` implementation for multi objective learning.
@ -1525,7 +1616,10 @@ class _MultiHead(Head):
if isinstance(logits, dict):
head_logits_pairs = []
for head in self._heads:
head_logits_pairs.append((head, logits[head.head_name]))
if isinstance(head, _LossOnlyHead):
head_logits_pairs.append((head, None))
else:
head_logits_pairs.append((head, logits[head.head_name]))
else:
# Split logits for each head.
head_logits_pairs = zip(self._heads, self._split_logits(logits))
@ -1606,6 +1700,8 @@ class _MultiHead(Head):
predictions = {}
output_alternatives = {}
for head, m in zip(self._heads, all_model_fn_ops):
if isinstance(head, _LossOnlyHead):
continue
head_name = head.head_name
output_alternatives[head_name] = m.output_alternatives[head_name]
for k, v in m.predictions.items():

View File

@ -1638,6 +1638,21 @@ class BinarySvmHeadTest(test.TestCase):
}, model_fn_ops)
class LossOnlyHead(test.TestCase):
def testNoPredictionsAndNoMetrics(self):
head = head_lib.loss_only_head(lambda: 1, head_name="const")
model_fn_ops = head.create_model_fn_ops(
features={},
mode=model_fn.ModeKeys.TRAIN,
train_op_fn=head_lib.no_op_train_fn)
self.assertDictEqual(model_fn_ops.predictions, {})
self.assertDictEqual(model_fn_ops.eval_metric_ops, {})
self.assertIsNotNone(model_fn_ops.loss)
with session.Session() as sess:
self.assertEqual(1, sess.run(model_fn_ops.loss))
class MultiHeadTest(test.TestCase):
def testInvalidHeads(self):
@ -1672,7 +1687,8 @@ class MultiHeadTest(test.TestCase):
n_classes=3, label_name="label1", head_name="head1")
head2 = head_lib.multi_class_head(
n_classes=4, label_name="label2", head_name="head2")
head = head_lib.multi_head((head1, head2))
head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")
head = head_lib.multi_head((head1, head2, head3))
labels = {
"label1": (1,),
"label2": (1,)
@ -1691,7 +1707,7 @@ class MultiHeadTest(test.TestCase):
self.assertIsNone(model_fn_ops.output_alternatives)
with session.Session() as sess:
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)
def testTrain_withHeadWeights(self):
head1 = head_lib.multi_class_head(