Add loss_only_head to hold additional loss terms for multi_head setup
PiperOrigin-RevId: 157875934
commit 9e25c68ad1
parent 7cdcd0cca2
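For context, a minimal usage sketch of the API added in this change (the head names and the choice of `tf.losses.get_regularization_losses` as the loss callable are illustrative, not part of this commit): `loss_only_head` wraps extra loss terms, such as regularization, so they can be combined with ordinary heads in a `multi_head`.

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib

# Ordinary heads, each with its own labels and logits.
head1 = head_lib.multi_class_head(
    n_classes=3, label_name="label1", head_name="head1")
head2 = head_lib.multi_class_head(
    n_classes=4, label_name="label2", head_name="head2")

# Extra loss terms: a no-argument callable returning a scalar tensor or a
# list of scalar tensors. Here we use the graph's collected regularization
# losses (assumes some layers have registered regularizers).
reg_head = head_lib.loss_only_head(
    tf.losses.get_regularization_losses, head_name="regularizer")

# The loss-only head contributes its loss to the combined multi_head loss.
head = head_lib.multi_head((head1, head2, reg_head))
```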
@@ -308,6 +308,7 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_rea
 from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
 from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head
 from tensorflow.contrib.learn.python.learn.estimators.head import Head
+from tensorflow.contrib.learn.python.learn.estimators.head import loss_only_head
 from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head
 from tensorflow.contrib.learn.python.learn.estimators.head import multi_head
 from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head
@@ -429,6 +429,23 @@ def multi_label_head(n_classes,
       loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)


+def loss_only_head(loss_fn, head_name=None):
+  """Creates a Head that contains only loss terms.
+
+  Loss only head holds additional loss terms to be added to other heads and
+  usually represents additional regularization terms in the objective function.
+
+  Args:
+    loss_fn: a function that takes no argument and returns a list of
+        scalar tensors.
+    head_name: a name for the head.
+
+  Returns:
+    An instance of `Head` to hold the additional losses.
+  """
+  return _LossOnlyHead(loss_fn, head_name=head_name)
+
+
 def multi_head(heads, loss_weights=None):
   """Creates a MultiHead stemming from same logits/hidden layer.

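To make the `loss_fn` contract above concrete, a small hedged sketch (constants chosen arbitrarily): the callable takes no arguments and returns a list of scalar tensors; the implementation added in the next hunk also accepts a single scalar tensor, which it uses as-is.

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib

# A list of scalar loss terms; they are summed into one loss.
list_head = head_lib.loss_only_head(
    lambda: [tf.constant(0.5), tf.constant(0.25)], head_name="list_losses")

# A single scalar loss term is also accepted.
scalar_head = head_lib.loss_only_head(
    lambda: tf.constant(0.75), head_name="scalar_loss")
```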
@@ -1406,6 +1423,80 @@ class _MultiLabelHead(_SingleHead):
     return metrics


+class _LossOnlyHead(Head):
+  """`Head` implementation for additional loss terms.
+
+  This class only holds loss terms unrelated to any other heads (labels),
+  e.g. regularization.
+
+  Common usage:
+  This is often combined with other heads in a multi-head setup.
+    ```python
+    head = multi_head([
+        head1, head2, loss_only_head(regularizer, 'regularizer')])
+    ```
+  """
+
+  def __init__(self, loss_fn, head_name=None):
+    self._loss_fn = loss_fn
+    self.head_name = head_name or "loss_only_head"
+
+  @property
+  def logits_dimension(self):
+    return 0
+
+  def create_model_fn_ops(self,
+                          features,
+                          mode,
+                          labels=None,
+                          train_op_fn=None,
+                          logits=None,
+                          logits_input=None,
+                          scope=None):
+    """See `_Head.create_model_fn_ops`.
+
+    Args:
+      features: Not used.
+      mode: Estimator's `ModeKeys`.
+      labels: Labels `Tensor`, or `dict` of same.
+      train_op_fn: Function that takes a scalar loss and returns an op to
+          optimize with the loss.
+      logits: Not used.
+      logits_input: Not used.
+      scope: Optional scope for variable_scope. If provided, will be passed to
+          all heads. Most users will want to set this to `None`, so each head
+          constructs a separate variable_scope according to its `head_name`.
+
+    Returns:
+      A `ModelFnOps` object.
+
+    Raises:
+      ValueError: if `mode` is not recognized.
+    """
+    _check_mode_valid(mode)
+    loss = None
+    train_op = None
+    if mode != model_fn.ModeKeys.INFER:
+      with variable_scope.variable_scope(scope, default_name=self.head_name):
+        loss = self._loss_fn()
+        if isinstance(loss, list):
+          loss = math_ops.add_n(loss)
+        logging_ops.scalar_summary(
+            _summary_key(self.head_name, mkey.LOSS), loss)
+        if mode == model_fn.ModeKeys.TRAIN:
+          if train_op_fn is None:
+            raise ValueError("train_op_fn can not be None in TRAIN mode")
+          with ops.name_scope(None, "train_op", (loss,)):
+            train_op = train_op_fn(loss)
+
+    return model_fn.ModelFnOps(
+        mode=mode,
+        loss=loss,
+        train_op=train_op,
+        predictions={},
+        eval_metric_ops={})
+
+
 class _MultiHead(Head):
   """`Head` implementation for multi objective learning.

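A hedged sketch of what `create_model_fn_ops` on a loss-only head produces in each mode (the constants and head name are illustrative): outside of INFER the loss terms are evaluated, a list is summed with `add_n`, and a loss summary is recorded; predictions and eval metrics are always empty, and a train_op is built only in TRAIN mode.

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn

reg_head = head_lib.loss_only_head(
    lambda: [tf.constant(0.5), tf.constant(0.25)], head_name="regularizer")

# EVAL: loss is the summed terms (0.75), train_op is None, and both
# predictions and eval_metric_ops are empty dicts.
eval_ops = reg_head.create_model_fn_ops(
    features={}, mode=model_fn.ModeKeys.EVAL)

# INFER: the loss_fn is never called, so loss and train_op are both None.
infer_ops = reg_head.create_model_fn_ops(
    features={}, mode=model_fn.ModeKeys.INFER)
```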
@@ -1525,6 +1616,9 @@ class _MultiHead(Head):
     if isinstance(logits, dict):
       head_logits_pairs = []
       for head in self._heads:
-        head_logits_pairs.append((head, logits[head.head_name]))
+        if isinstance(head, _LossOnlyHead):
+          head_logits_pairs.append((head, None))
+        else:
+          head_logits_pairs.append((head, logits[head.head_name]))
     else:
       # Split logits for each head.
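For illustration, a sketch of the per-head logits `dict` accepted by `_MultiHead` once a loss-only head is in the mix (the logits and labels values are hypothetical and merely mirror the test added below): the loss-only head needs no entry, since it is paired with `logits=None` internally.

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn

head1 = head_lib.multi_class_head(
    n_classes=3, label_name="label1", head_name="head1")
head2 = head_lib.multi_class_head(
    n_classes=4, label_name="label2", head_name="head2")
head3 = head_lib.loss_only_head(lambda: tf.constant(1.0), head_name="const")
head = head_lib.multi_head((head1, head2, head3))

# Per-head logits are keyed by head_name; there is no entry for head3.
logits = {
    "head1": tf.constant([[-0.7, 0.2, 0.1]]),
    "head2": tf.constant([[-0.7, 0.2, 0.1, 0.1]]),
}
labels = {"label1": (1,), "label2": (1,)}
model_fn_ops = head.create_model_fn_ops(
    features={},
    mode=model_fn.ModeKeys.TRAIN,
    labels=labels,
    train_op_fn=head_lib.no_op_train_fn,
    logits=logits)
```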
@@ -1606,6 +1700,8 @@ class _MultiHead(Head):
     predictions = {}
     output_alternatives = {}
     for head, m in zip(self._heads, all_model_fn_ops):
+      if isinstance(head, _LossOnlyHead):
+        continue
       head_name = head.head_name
       output_alternatives[head_name] = m.output_alternatives[head_name]
       for k, v in m.predictions.items():
@@ -1638,6 +1638,21 @@ class BinarySvmHeadTest(test.TestCase):
       }, model_fn_ops)


+class LossOnlyHead(test.TestCase):
+
+  def testNoPredictionsAndNoMetrics(self):
+    head = head_lib.loss_only_head(lambda: 1, head_name="const")
+    model_fn_ops = head.create_model_fn_ops(
+        features={},
+        mode=model_fn.ModeKeys.TRAIN,
+        train_op_fn=head_lib.no_op_train_fn)
+    self.assertDictEqual(model_fn_ops.predictions, {})
+    self.assertDictEqual(model_fn_ops.eval_metric_ops, {})
+    self.assertIsNotNone(model_fn_ops.loss)
+    with session.Session() as sess:
+      self.assertEqual(1, sess.run(model_fn_ops.loss))
+
+
 class MultiHeadTest(test.TestCase):

   def testInvalidHeads(self):
@@ -1672,7 +1687,8 @@ class MultiHeadTest(test.TestCase):
         n_classes=3, label_name="label1", head_name="head1")
     head2 = head_lib.multi_class_head(
         n_classes=4, label_name="label2", head_name="head2")
-    head = head_lib.multi_head((head1, head2))
+    head3 = head_lib.loss_only_head(lambda: 1.0, head_name="const")
+    head = head_lib.multi_head((head1, head2, head3))
     labels = {
         "label1": (1,),
         "label2": (1,)
@@ -1691,7 +1707,7 @@ class MultiHeadTest(test.TestCase):
     self.assertIsNone(model_fn_ops.output_alternatives)

     with session.Session() as sess:
-      self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
+      self.assertAlmostEqual(3.224, sess.run(model_fn_ops.loss), places=3)

   def testTrain_withHeadWeights(self):
     head1 = head_lib.multi_class_head(
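The expected combined loss in this test rises from 2.224 to 3.224 because, assuming the multi_head loss is the unweighted sum of its heads' losses, the new loss-only head contributes its constant 1.0: 2.224 + 1.0 = 3.224.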