Add loss_only_head to hold additional loss terms for multi_head setup
PiperOrigin-RevId: 157875934
This commit is contained in:
parent
7cdcd0cca2
commit
9e25c68ad1
tensorflow/contrib/learn/python/learn/estimators
@ -308,6 +308,7 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_rea
|
||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import Head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import loss_only_head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import multi_head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head
|
||||
|
@ -429,6 +429,23 @@ def multi_label_head(n_classes,
|
||||
loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)
|
||||
|
||||
|
||||
def loss_only_head(loss_fn, head_name=None):
  """Creates a Head that contains only loss terms.

  Loss only head holds additional loss terms to be added to other heads and
  usually represents additional regularization terms in the objective function.

  Args:
    loss_fn: a function that takes no argument and returns a list of
        scalar tensors.
    head_name: a name for the head.

  Returns:
    An instance of `Head` to hold the additional losses.
  """
  return _LossOnlyHead(loss_fn, head_name=head_name)
|
||||
|
||||
|
||||
def multi_head(heads, loss_weights=None):
|
||||
"""Creates a MultiHead stemming from same logits/hidden layer.
|
||||
|
||||
@ -1406,6 +1423,80 @@ class _MultiLabelHead(_SingleHead):
|
||||
return metrics
|
||||
|
||||
|
||||
class _LossOnlyHead(Head):
  """`Head` implementation for additional loss terms.

  This class only holds loss terms unrelated to any other heads (labels),
  e.g. regularization.

  Common usage:
  This is often combined with other heads in a multi head setup.
    ```python
    head = multi_head([
        head1, head2, loss_only_head(regularizer, head_name='regularizer')])
    ```
  """

  def __init__(self, loss_fn, head_name=None):
    self._loss_fn = loss_fn
    self.head_name = head_name or "loss_only_head"

  @property
  def logits_dimension(self):
    # This head produces no logits; it only contributes a loss term.
    return 0

  def create_model_fn_ops(self,
                          features,
                          mode,
                          labels=None,
                          train_op_fn=None,
                          logits=None,
                          logits_input=None,
                          scope=None):
    """See `_Head.create_model_fn_ops`.

    Args:
      features: Not been used.
      mode: Estimator's `ModeKeys`.
      labels: Labels `Tensor`, or `dict` of same.
      train_op_fn: Function that takes a scalar loss and returns an op to
          optimize with the loss.
      logits: Not been used.
      logits_input: Not been used.
      scope: Optional scope for variable_scope. If provided, will be passed to
          all heads. Most users will want to set this to `None`, so each head
          constructs a separate variable_scope according to its `head_name`.

    Returns:
      A `ModelFnOps` object.

    Raises:
      ValueError: if `mode` is not recognized.
    """
    _check_mode_valid(mode)
    loss = None
    train_op = None
    # No loss is computed for INFER; this head contributes nothing at
    # prediction time.
    if mode != model_fn.ModeKeys.INFER:
      with variable_scope.variable_scope(scope, default_name=self.head_name):
        loss = self._loss_fn()
        # loss_fn may return a list of scalar tensors; reduce to a scalar.
        if isinstance(loss, list):
          loss = math_ops.add_n(loss)
        logging_ops.scalar_summary(
            _summary_key(self.head_name, mkey.LOSS), loss)
        if mode == model_fn.ModeKeys.TRAIN:
          if train_op_fn is None:
            raise ValueError("train_op_fn can not be None in TRAIN mode")
          with ops.name_scope(None, "train_op", (loss,)):
            train_op = train_op_fn(loss)

    # No predictions and no eval metrics: this head only carries a loss.
    return model_fn.ModelFnOps(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions={},
        eval_metric_ops={})
|
||||
|
||||
|
||||
class _MultiHead(Head):
|
||||
"""`Head` implementation for multi objective learning.
|
||||
|
||||
@ -1525,7 +1616,10 @@ class _MultiHead(Head):
|
||||
if isinstance(logits, dict):
|
||||
head_logits_pairs = []
|
||||
for head in self._heads:
|
||||
head_logits_pairs.append((head, logits[head.head_name]))
|
||||
if isinstance(head, _LossOnlyHead):
|
||||
head_logits_pairs.append((head, None))
|
||||
else:
|
||||
head_logits_pairs.append((head, logits[head.head_name]))
|
||||
else:
|
||||
# Split logits for each head.
|
||||
head_logits_pairs = zip(self._heads, self._split_logits(logits))
|
||||
@ -1606,6 +1700,8 @@ class _MultiHead(Head):
|
||||
predictions = {}
|
||||
output_alternatives = {}
|
||||
for head, m in zip(self._heads, all_model_fn_ops):
|
||||
if isinstance(head, _LossOnlyHead):
|
||||
continue
|
||||
head_name = head.head_name
|
||||
output_alternatives[head_name] = m.output_alternatives[head_name]
|
||||
for k, v in m.predictions.items():
|
||||
|
@ -1638,6 +1638,21 @@ class BinarySvmHeadTest(test.TestCase):
|
||||
}, model_fn_ops)
|
||||
|
||||
|
||||
class LossOnlyHead(test.TestCase):
  """Tests for heads created via head_lib.loss_only_head."""

  def testNoPredictionsAndNoMetrics(self):
    """A loss-only head yields a loss but no predictions or eval metrics."""
    const_head = head_lib.loss_only_head(lambda: 1, head_name="const")
    ops = const_head.create_model_fn_ops(
        features={},
        mode=model_fn.ModeKeys.TRAIN,
        train_op_fn=head_lib.no_op_train_fn)
    self.assertIsNotNone(ops.loss)
    self.assertDictEqual({}, ops.predictions)
    self.assertDictEqual({}, ops.eval_metric_ops)
    with session.Session() as sess:
      self.assertEqual(1, sess.run(ops.loss))
|
||||
|
||||
|
||||
class MultiHeadTest(test.TestCase):
|
||||
|
||||
def testInvalidHeads(self):
|
||||
@ -1672,7 +1687,8 @@ class MultiHeadTest(test.TestCase):
|
||||
n_classes=3, label_name="label1", head_name="head1")
|
||||
head2 = head_lib.multi_class_head(
|
||||
n_classes=4, label_name="label2", head_name="head2")
|
||||
head = head_lib.multi_head((head1, head2))
|
||||
head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")
|
||||
head = head_lib.multi_head((head1, head2, head3))
|
||||
labels = {
|
||||
"label1": (1,),
|
||||
"label2": (1,)
|
||||
@ -1691,7 +1707,7 @@ class MultiHeadTest(test.TestCase):
|
||||
self.assertIsNone(model_fn_ops.output_alternatives)
|
||||
|
||||
with session.Session() as sess:
|
||||
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
|
||||
self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)
|
||||
|
||||
def testTrain_withHeadWeights(self):
|
||||
head1 = head_lib.multi_class_head(
|
||||
|
Loading…
Reference in New Issue
Block a user