Do not calculate train_op in eval mode for _BinaryLogisticHead.
Change: 140555724
This commit is contained in:
parent
35b6050b57
commit
347d3ef2a8
@ -508,7 +508,7 @@ class _BinaryLogisticHead(_Head):
|
||||
eval_metric_ops = None
|
||||
else:
|
||||
loss = self._training_loss(features, labels, logits)
|
||||
train_op = (None if train_op_fn is None
|
||||
train_op = (None if train_op_fn is None or mode == model_fn.ModeKeys.EVAL
|
||||
else self._train_op(loss, labels, train_op_fn))
|
||||
eval_metric_ops = self._eval_metric_ops(features, labels, predictions)
|
||||
|
||||
|
@ -36,6 +36,11 @@ class RegressionModelHeadTest(tf.test.TestCase):
|
||||
_noop_train_op, logits=prediction)
|
||||
self.assertAlmostEqual(5. / 3, sess.run(model_fn_ops.loss))
|
||||
|
||||
model_fn_ops = head.head_ops({}, labels,
|
||||
tf.contrib.learn.ModeKeys.EVAL,
|
||||
_noop_train_op, logits=prediction)
|
||||
self.assertIsNone(model_fn_ops.train_op)
|
||||
|
||||
def testRegressionWithWeights(self):
|
||||
head = head_lib._regression_head(
|
||||
weight_column_name="label_weight")
|
||||
@ -74,6 +79,11 @@ class MultiLabelModelHeadTest(tf.test.TestCase):
|
||||
_noop_train_op, logits=logits)
|
||||
self.assertAlmostEqual(0.89985204, sess.run(model_fn_ops.loss))
|
||||
|
||||
model_fn_ops = head.head_ops({}, labels,
|
||||
tf.contrib.learn.ModeKeys.EVAL,
|
||||
_noop_train_op, logits=logits)
|
||||
self.assertIsNone(model_fn_ops.train_op)
|
||||
|
||||
def testMultiLabelWithWeight(self):
|
||||
head = head_lib._multi_label_head(
|
||||
n_classes=3, weight_column_name="label_weight")
|
||||
@ -101,6 +111,10 @@ class MultiClassModelHeadTest(tf.test.TestCase):
|
||||
_noop_train_op, logits=logits)
|
||||
self.assertAlmostEqual(0.81326175, sess.run(model_fn_ops.loss),
|
||||
delta=1e-6)
|
||||
model_fn_ops = head.head_ops({}, labels,
|
||||
tf.contrib.learn.ModeKeys.EVAL,
|
||||
_noop_train_op, logits=logits)
|
||||
self.assertIsNone(model_fn_ops.train_op)
|
||||
|
||||
def testErrorInSparseTensorLabels(self):
|
||||
head = head_lib._multi_class_head(n_classes=2)
|
||||
@ -141,6 +155,10 @@ class MultiClassModelHeadTest(tf.test.TestCase):
|
||||
tf.contrib.learn.ModeKeys.TRAIN,
|
||||
_noop_train_op, logits=logits)
|
||||
self.assertAlmostEqual(1.5514446, sess.run(model_fn_ops.loss))
|
||||
model_fn_ops = head.head_ops({}, labels,
|
||||
tf.contrib.learn.ModeKeys.EVAL,
|
||||
_noop_train_op, logits=logits)
|
||||
self.assertIsNone(model_fn_ops.train_op)
|
||||
|
||||
def testMultiClassWithWeight(self):
|
||||
head = head_lib._multi_class_head(
|
||||
@ -178,6 +196,11 @@ class BinarySvmModelHeadTest(tf.test.TestCase):
|
||||
with tf.Session() as sess:
|
||||
self.assertAlmostEqual(0.25, sess.run(model_fn_ops.loss))
|
||||
|
||||
model_fn_ops = head.head_ops({}, labels,
|
||||
tf.contrib.learn.ModeKeys.EVAL,
|
||||
_noop_train_op, logits=predictions)
|
||||
self.assertIsNone(model_fn_ops.train_op)
|
||||
|
||||
def testBinarySVMWithWeights(self):
|
||||
head = head_lib._binary_svm_head(
|
||||
weight_column_name="weights")
|
||||
|
Loading…
Reference in New Issue
Block a user