Do not calculate train_op in eval mode for _BinaryLogisticHead.

Change: 140555724
This commit is contained in:
A. Unique TensorFlower 2016-11-29 18:40:08 -08:00 committed by TensorFlower Gardener
parent 35b6050b57
commit 347d3ef2a8
2 changed files with 24 additions and 1 deletions

View File

@ -508,7 +508,7 @@ class _BinaryLogisticHead(_Head):
eval_metric_ops = None
else:
loss = self._training_loss(features, labels, logits)
train_op = (None if train_op_fn is None
train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL
else self._train_op(loss, labels, train_op_fn))
eval_metric_ops = self._eval_metric_ops(features, labels, predictions)

View File

@ -36,6 +36,11 @@ class RegressionModelHeadTest(tf.test.TestCase):
_noop_train_op, logits=prediction)
self.assertAlmostEqual(5. / 3, sess.run(model_fn_ops.loss))
model_fn_ops = head.head_ops({}, labels,
tf.contrib.learn.ModeKeys.EVAL,
_noop_train_op, logits=prediction)
self.assertIsNone(model_fn_ops.train_op)
def testRegressionWithWeights(self):
head = head_lib._regression_head(
weight_column_name="label_weight")
@ -74,6 +79,11 @@ class MultiLabelModelHeadTest(tf.test.TestCase):
_noop_train_op, logits=logits)
self.assertAlmostEqual(0.89985204, sess.run(model_fn_ops.loss))
model_fn_ops = head.head_ops({}, labels,
tf.contrib.learn.ModeKeys.EVAL,
_noop_train_op, logits=logits)
self.assertIsNone(model_fn_ops.train_op)
def testMultiLabelWithWeight(self):
head = head_lib._multi_label_head(
n_classes=3, weight_column_name="label_weight")
@ -101,6 +111,10 @@ class MultiClassModelHeadTest(tf.test.TestCase):
_noop_train_op, logits=logits)
self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss),
delta=1e-6)
model_fn_ops = head.head_ops({}, labels,
tf.contrib.learn.ModeKeys.EVAL,
_noop_train_op, logits=logits)
self.assertIsNone(model_fn_ops.train_op)
def testErrorInSparseTensorLabels(self):
head = head_lib._multi_class_head(n_classes=2)
@ -141,6 +155,10 @@ class MultiClassModelHeadTest(tf.test.TestCase):
tf.contrib.learn.ModeKeys.TRAIN,
_noop_train_op, logits=logits)
self.assertAlmostEqual(1.5514446, sess.run(model_fn_ops.loss))
model_fn_ops = head.head_ops({}, labels,
tf.contrib.learn.ModeKeys.EVAL,
_noop_train_op, logits=logits)
self.assertIsNone(model_fn_ops.train_op)
def testMultiClassWithWeight(self):
head = head_lib._multi_class_head(
@ -178,6 +196,11 @@ class BinarySvmModelHeadTest(tf.test.TestCase):
with tf.Session() as sess:
self.assertAlmostEqual(0.25, sess.run(model_fn_ops.loss))
model_fn_ops = head.head_ops({}, labels,
tf.contrib.learn.ModeKeys.EVAL,
_noop_train_op, logits=predictions)
self.assertIsNone(model_fn_ops.train_op)
def testBinarySVMWithWeights(self):
head = head_lib._binary_svm_head(
weight_column_name="weights")