Add label-vocab support to binary logistic head.

Add assertion that binary classifier label is in range [0., 1.]
Fixed Classifier Integration tests.

PiperOrigin-RevId: 158307521
This commit is contained in:
Mustafa Ispir 2017-06-07 13:15:04 -07:00 committed by TensorFlower Gardener
parent f105df0478
commit edb5fed7fc
4 changed files with 220 additions and 150 deletions

View File

@ -322,6 +322,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
if self._model_dir:
shutil.rmtree(self._model_dir)
def _as_label(self, data_in_float):
return np.rint(data_in_float).astype(np.int64)
def _test_complete_flow(
self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
n_classes, batch_size):
@ -363,12 +366,13 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def test_numpy_input_fn(self):
"""Tests complete flow with numpy_input_fn."""
n_classes = 2
n_classes = 3
input_dimension = 2
batch_size = 10
data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
data = np.linspace(
0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
x_data = data.reshape(batch_size, input_dimension)
y_data = np.reshape(data[:batch_size], (batch_size, 1))
y_data = self._as_label(np.reshape(data[:batch_size], (batch_size, 1)))
# learn y = x
train_input_fn = numpy_io.numpy_input_fn(
x={'x': x_data},
@ -401,9 +405,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
input_dimension = 1
n_classes = 2
batch_size = 10
data = np.linspace(0., 2., batch_size, dtype=np.float32)
data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
x = pd.DataFrame({'x': data})
y = pd.Series(data)
y = pd.Series(self._as_label(data))
train_input_fn = pandas_io.pandas_input_fn(
x=x,
y=y,
@ -431,25 +435,28 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def test_input_fn_from_parse_example(self):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
n_classes = 2
n_classes = 3
batch_size = 10
data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
data = np.linspace(0., n_classes-1., batch_size * input_dimension,
dtype=np.float32)
data = data.reshape(batch_size, input_dimension)
serialized_examples = []
for datum in data:
example = example_pb2.Example(features=feature_pb2.Features(
feature={
'x': feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=datum)),
'y': feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=datum[:1])),
'x':
feature_pb2.Feature(float_list=feature_pb2.FloatList(
value=datum)),
'y':
feature_pb2.Feature(int64_list=feature_pb2.Int64List(
value=self._as_label(datum[:1]))),
}))
serialized_examples.append(example.SerializeToString())
feature_spec = {
'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
'y': parsing_ops.FixedLenFeature([1], dtypes.float32),
'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
}
def _train_input_fn():
feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)

View File

@ -300,12 +300,18 @@ class DNNClassifierPredictTest(test.TestCase):
# logistic = exp(-2.08)/(1 + exp(-2.08)) = 0.11105597
# probabilities = [1-logistic, logistic] = [0.88894403, 0.11105597]
# class_ids = argmax(probabilities) = [0]
self.assertAllClose({
prediction_keys.PredictionKeys.LOGITS: [-2.08],
prediction_keys.PredictionKeys.LOGISTIC: [0.11105597],
prediction_keys.PredictionKeys.PROBABILITIES: [0.88894403, 0.11105597],
prediction_keys.PredictionKeys.CLASS_IDS: [0],
}, next(dnn_classifier.predict(input_fn=input_fn)))
predictions = next(dnn_classifier.predict(input_fn=input_fn))
self.assertAllClose([-2.08],
predictions[prediction_keys.PredictionKeys.LOGITS])
self.assertAllClose([0.11105597],
predictions[prediction_keys.PredictionKeys.LOGISTIC])
self.assertAllClose(
[0.88894403,
0.11105597], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
self.assertAllClose([0],
predictions[prediction_keys.PredictionKeys.CLASS_IDS])
self.assertAllEqual([b'0'],
predictions[prediction_keys.PredictionKeys.CLASSES])
def test_multi_dim(self):
"""Asserts predictions for multi-dimensional input and logits."""
@ -535,6 +541,9 @@ class DNNClassifierIntegrationTest(test.TestCase):
if self._model_dir:
shutil.rmtree(self._model_dir)
def _as_label(self, data_in_float):
return np.rint(data_in_float).astype(np.int64)
def _test_complete_flow(
self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
n_classes, batch_size):
@ -572,12 +581,13 @@ class DNNClassifierIntegrationTest(test.TestCase):
def test_numpy_input_fn(self):
"""Tests complete flow with numpy_input_fn."""
n_classes = 2
n_classes = 3
input_dimension = 2
batch_size = 10
data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
data = np.linspace(
0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
x_data = data.reshape(batch_size, input_dimension)
y_data = np.reshape(data[:batch_size], (batch_size, 1))
y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
# learn y = x
train_input_fn = numpy_io.numpy_input_fn(
x={'x': x_data},
@ -608,11 +618,11 @@ class DNNClassifierIntegrationTest(test.TestCase):
if not HAS_PANDAS:
return
input_dimension = 1
n_classes = 2
n_classes = 3
batch_size = 10
data = np.linspace(0., 2., batch_size, dtype=np.float32)
data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
x = pd.DataFrame({'x': data})
y = pd.Series(data)
y = pd.Series(self._as_label(data))
train_input_fn = pandas_io.pandas_input_fn(
x=x,
y=y,
@ -640,25 +650,28 @@ class DNNClassifierIntegrationTest(test.TestCase):
def test_input_fn_from_parse_example(self):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
n_classes = 2
n_classes = 3
batch_size = 10
data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
data = np.linspace(
0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
data = data.reshape(batch_size, input_dimension)
serialized_examples = []
for datum in data:
example = example_pb2.Example(features=feature_pb2.Features(
feature={
'x': feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=datum)),
'y': feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=datum[:1])),
'x':
feature_pb2.Feature(float_list=feature_pb2.FloatList(
value=datum)),
'y':
feature_pb2.Feature(int64_list=feature_pb2.Int64List(
value=self._as_label(datum[:1]))),
}))
serialized_examples.append(example.SerializeToString())
feature_spec = {
'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
'y': parsing_ops.FixedLenFeature([1], dtypes.float32),
'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
}
def _train_input_fn():
feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)

View File

@ -302,7 +302,8 @@ def _multi_class_head_with_softmax_cross_entropy_loss(n_classes,
Raises:
ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid.
"""
if label_vocabulary is not None and not isinstance(label_vocabulary, list):
if label_vocabulary is not None and not isinstance(label_vocabulary,
(list, tuple)):
raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
type(label_vocabulary)))
@ -356,14 +357,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
label_ids = lookup_ops.index_table_from_tensor(
vocabulary_list=tuple(self._label_vocabulary),
name='class_id_lookup').lookup(labels)
assert_less = check_ops.assert_less(
label_ids,
ops.convert_to_tensor(self._n_classes, dtype=label_ids.dtype),
message='Label IDs must < n_classes')
assert_greater = check_ops.assert_non_negative(
label_ids, message='Label Ids must >= 0')
with ops.control_dependencies((assert_less, assert_greater)):
return array_ops.identity(label_ids)
return _assert_range(label_ids, self._n_classes)
def create_estimator_spec(
self, features, mode, logits, labels=None, train_op_fn=None):
@ -459,7 +453,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
weight_feature_key=None, thresholds=None):
weight_feature_key=None, thresholds=None, label_vocabulary=None):
"""Creates a `Head` for single label binary classification.
This head uses `sigmoid_cross_entropy_with_logits` loss.
@ -475,6 +469,11 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
generated for each threshold value. This threshold is applied to the
logistic values to determine the binary classification (i.e., above the
threshold is `true`, below is `false`.
label_vocabulary: A list of strings represents possible label values. If it
is not given, that means labels are already encoded within [0, 1]. If
given, labels must be string type and have any value in
`label_vocabulary`. Also there will be errors if vocabulary is not
provided and labels are string.
Returns:
An instance of `Head` for binary classification.
@ -483,50 +482,81 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
ValueError: if `thresholds` contains a value outside of `(0, 1)`.
"""
thresholds = tuple(thresholds) if thresholds else tuple()
if label_vocabulary is not None and not isinstance(label_vocabulary,
(list, tuple)):
raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
type(label_vocabulary)))
for threshold in thresholds:
if (threshold <= 0.0) or (threshold >= 1.0):
raise ValueError('thresholds not in (0, 1): %s.' % (thresholds,))
return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
weight_feature_key=weight_feature_key, thresholds=thresholds)
weight_feature_key=weight_feature_key,
thresholds=thresholds,
label_vocabulary=label_vocabulary)
class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
"""See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`."""
def __init__(self, weight_feature_key=None, thresholds=None):
def __init__(self,
weight_feature_key=None,
thresholds=None,
label_vocabulary=None):
self._weight_feature_key = weight_feature_key
self._thresholds = thresholds
self._label_vocabulary = label_vocabulary
@property
def logits_dimension(self):
return 1
def _eval_metric_ops(
self, labels, logits, logistic, scores, classes, unweighted_loss,
weights=None):
with ops.name_scope(
None, 'metrics',
(labels, logits, logistic, scores, classes, unweighted_loss, weights)):
def _eval_metric_ops(self,
labels,
logits,
logistic,
scores,
class_ids,
unweighted_loss,
weights=None):
with ops.name_scope(None, 'metrics', (labels, logits, logistic, scores,
class_ids, unweighted_loss, weights)):
keys = metric_keys.MetricKeys
labels_mean = _indicator_labels_mean(
labels=labels, weights=weights, name=keys.LABEL_MEAN)
metric_ops = {
# Estimator already adds a metric for loss.
keys.LOSS_MEAN: metrics_lib.mean(
unweighted_loss, weights=weights, name=keys.LOSS_MEAN),
keys.ACCURACY: metrics_lib.accuracy(
labels=labels, predictions=classes, weights=weights,
name=keys.ACCURACY),
keys.PREDICTION_MEAN: _predictions_mean(
predictions=logistic, weights=weights, name=keys.PREDICTION_MEAN),
keys.LABEL_MEAN: labels_mean,
keys.ACCURACY_BASELINE: _accuracy_baseline(labels_mean),
keys.AUC: _auc(
labels=labels, predictions=logistic, weights=weights,
name=keys.AUC),
keys.AUC_PR: _auc(
labels=labels, predictions=logistic, weights=weights, curve='PR',
name=keys.AUC_PR)
keys.LOSS_MEAN:
metrics_lib.mean(
unweighted_loss, weights=weights, name=keys.LOSS_MEAN),
keys.ACCURACY:
metrics_lib.accuracy(
labels=labels,
predictions=class_ids,
weights=weights,
name=keys.ACCURACY),
keys.PREDICTION_MEAN:
_predictions_mean(
predictions=logistic,
weights=weights,
name=keys.PREDICTION_MEAN),
keys.LABEL_MEAN:
labels_mean,
keys.ACCURACY_BASELINE:
_accuracy_baseline(labels_mean),
keys.AUC:
_auc(
labels=labels,
predictions=logistic,
weights=weights,
name=keys.AUC),
keys.AUC_PR:
_auc(
labels=labels,
predictions=logistic,
weights=weights,
curve='PR',
name=keys.AUC_PR)
}
for threshold in self._thresholds:
accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
@ -559,27 +589,39 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
two_class_logits = array_ops.concat(
(array_ops.zeros_like(logits), logits), 1, name='two_class_logits')
scores = nn.softmax(two_class_logits, name=pred_keys.PROBABILITIES)
classes = array_ops.reshape(
class_ids = array_ops.reshape(
math_ops.argmax(two_class_logits, axis=1), (-1, 1), name='classes')
if self._label_vocabulary:
table = lookup_ops.index_to_string_table_from_tensor(
vocabulary_list=self._label_vocabulary, name='class_string_lookup')
classes = table.lookup(class_ids)
else:
classes = string_ops.as_string(class_ids, name='str_classes')
predictions = {
pred_keys.LOGITS: logits,
pred_keys.LOGISTIC: logistic,
pred_keys.PROBABILITIES: scores,
pred_keys.CLASS_IDS: classes
pred_keys.CLASS_IDS: class_ids,
pred_keys.CLASSES: classes,
}
if mode == model_fn.ModeKeys.PREDICT:
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={'': export_output.ClassificationOutput(
scores=scores,
# `ClassificationOutput` requires string classes.
# TODO(ptucker): Support label_keys.
classes=string_ops.as_string(classes, name='str_classes'))})
export_outputs={
'':
export_output.ClassificationOutput(
scores=scores, classes=classes)
})
# Eval.
labels = _check_labels(_maybe_expand_dim(math_ops.to_float(labels)),
self.logits_dimension)
labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension)
if self._label_vocabulary is not None:
labels = lookup_ops.index_table_from_tensor(
vocabulary_list=tuple(self._label_vocabulary),
name='class_id_lookup').lookup(labels)
labels = math_ops.to_float(labels)
labels = _assert_range(labels, 2)
unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
labels=labels, logits=logits, name='loss')
weights = (
@ -598,7 +640,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
logits=logits,
logistic=logistic,
scores=scores,
classes=classes,
class_ids=class_ids,
unweighted_loss=unweighted_loss,
weights=weights))
@ -721,3 +763,14 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
predictions=predictions,
loss=training_loss,
train_op=train_op_fn(training_loss))
def _assert_range(labels, n_classes):
assert_less = check_ops.assert_less(
labels,
ops.convert_to_tensor(n_classes, dtype=labels.dtype),
message='Label IDs must < n_classes')
assert_greater = check_ops.assert_non_negative(
labels, message='Label IDs must >= 0')
with ops.control_dependencies((assert_less, assert_greater)):
return array_ops.identity(labels)

View File

@ -206,7 +206,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
})
with self.test_session():
with self.assertRaisesOpError('Label Ids must >= 0'):
with self.assertRaisesOpError('Label IDs must >= 0'):
spec.loss.eval({
labels_placeholder: labels_2x1_with_negative_id,
logits_placeholder: logits_2x3
@ -743,8 +743,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertEqual(1, head.logits_dimension)
# Both logits and labels should be shape (batch_size, 1).
values_2x1 = np.array(((43.,), (44.,),))
values_3x1 = np.array(((45.,), (46.,), (47.,),))
values_2x1 = np.array(((0.,), (1.,),))
values_3x1 = np.array(((0.,), (1.,), (0.,),))
# Static shape.
with self.assertRaisesRegexp(
@ -788,28 +788,13 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertEqual(1, head.logits_dimension)
# Create estimator spec.
logits = np.array(((45,), (-41,),), dtype=np.int32)
logits = [[45.], [-41.]]
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
expected_predictions = {
prediction_keys.PredictionKeys.LOGITS:
logits.astype(np.float32),
prediction_keys.PredictionKeys.LOGISTIC:
_sigmoid(logits).astype(np.float32),
prediction_keys.PredictionKeys.PROBABILITIES:
np.array(((0., 1.), (1., 0.),), dtype=np.float32),
prediction_keys.PredictionKeys.CLASS_IDS:
np.array(((1,), (0,)), dtype=np.int64),
}
# Assert spec contains expected tensors.
self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
self.assertEqual(
{k: v.dtype for k, v in six.iteritems(expected_predictions)},
{k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
self.assertIsNone(spec.loss)
self.assertEqual({}, spec.eval_metric_ops)
self.assertIsNone(spec.train_op)
@ -821,7 +806,37 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
self.assertAllClose(expected_predictions, sess.run(spec.predictions))
predictions = sess.run(spec.predictions)
self.assertAllClose(logits,
predictions[prediction_keys.PredictionKeys.LOGITS])
self.assertAllClose(
_sigmoid(np.array(logits)),
predictions[prediction_keys.PredictionKeys.LOGISTIC])
self.assertAllClose(
[[0., 1.],
[1., 0.]], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
self.assertAllClose([[1], [0]],
predictions[prediction_keys.PredictionKeys.CLASS_IDS])
self.assertAllEqual([[b'1'], [b'0']],
predictions[prediction_keys.PredictionKeys.CLASSES])
def test_predict_with_vocabulary_list(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
label_vocabulary=['aang', 'iroh'])
logits = [[1.], [0.]]
expected_classes = [[b'iroh'], [b'aang']]
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertAllEqual(
expected_classes,
sess.run(spec.predictions[prediction_keys.PredictionKeys.CLASSES]))
def test_eval(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
@ -834,17 +849,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits=logits,
labels=np.array(((1,), (1,),), dtype=np.int32))
expected_predictions = {
prediction_keys.PredictionKeys.LOGITS:
logits.astype(np.float32),
prediction_keys.PredictionKeys.LOGISTIC:
_sigmoid(logits).astype(np.float32),
prediction_keys.PredictionKeys.PROBABILITIES:
np.array(((0., 1.), (1., 0.),), dtype=np.float32),
# TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)?
prediction_keys.PredictionKeys.CLASS_IDS:
np.array(((1,), (0,)), dtype=np.int64),
}
keys = metric_keys.MetricKeys
expected_metrics = {
# loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41
@ -859,10 +863,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
}
# Assert spec contains expected tensors.
self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
self.assertEqual(
{k: v.dtype for k, v in six.iteritems(expected_predictions)},
{k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
self.assertIsNotNone(spec.loss)
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
@ -875,15 +875,34 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
predictions, loss, metrics = sess.run((
spec.predictions, spec.loss, update_ops))
self.assertAllClose(expected_predictions, predictions)
loss, metrics = sess.run((spec.loss, update_ops))
self.assertAllClose(41., loss)
# Check results of both update (in `metrics`) and value ops.
self.assertAllClose(expected_metrics, metrics)
self.assertAllClose(
expected_metrics, {k: value_ops[k].eval() for k in value_ops})
def test_eval_with_vocabulary_list(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
label_vocabulary=['aang', 'iroh'])
# Create estimator spec.
logits = np.array(((45,), (-41,),), dtype=np.float32)
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.float32)},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=[[b'iroh'], [b'iroh']])
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
sess.run(update_ops)
self.assertAllClose(1. / 2,
value_ops[metric_keys.MetricKeys.ACCURACY].eval())
def test_eval_with_thresholds(self):
thresholds = [0.25, 0.5, 0.75]
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
@ -942,23 +961,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
labels=np.array(((1,), (1,),), dtype=np.float64),
train_op_fn=_train_op_fn)
expected_predictions = {
prediction_keys.PredictionKeys.LOGITS:
logits.astype(np.float32),
prediction_keys.PredictionKeys.LOGISTIC:
_sigmoid(logits).astype(np.float32),
prediction_keys.PredictionKeys.PROBABILITIES:
np.array(((0., 1.), (1., 0.),), dtype=np.float32),
# TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)?
prediction_keys.PredictionKeys.CLASS_IDS:
np.array(((1,), (0,)), dtype=np.int64),
}
# Assert spec contains expected tensors.
self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
self.assertEqual(
{k: v.dtype for k, v in six.iteritems(expected_predictions)},
{k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
self.assertIsNotNone(spec.loss)
self.assertEqual({}, spec.eval_metric_ops)
self.assertIsNotNone(spec.train_op)
@ -969,9 +972,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
predictions, loss, train_result, summary_str = sess.run((
spec.predictions, spec.loss, spec.train_op, spec.scaffold.summary_op))
self.assertAllClose(expected_predictions, predictions)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
spec.scaffold.summary_op))
self.assertAllClose(expected_loss, loss)
self.assertEqual(expected_train_result, train_result)
_assert_simple_summaries(self, {
@ -995,28 +997,23 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
expected_predictions = {
prediction_keys.PredictionKeys.LOGITS:
logits.astype(np.float32),
prediction_keys.PredictionKeys.LOGISTIC:
_sigmoid(logits).astype(np.float32),
prediction_keys.PredictionKeys.PROBABILITIES:
np.array(((0., 1.), (1., 0.), (0., 1.)), dtype=np.float32),
# TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)?
prediction_keys.PredictionKeys.CLASS_IDS:
np.array(((1,), (0,), (1,)), dtype=np.int64),
}
# Assert spec contains expected tensors.
self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
self.assertEqual(
{k: v.dtype for k, v in six.iteritems(expected_predictions)},
{k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
# Assert predictions, loss, and metrics.
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertAllClose(expected_predictions, sess.run(spec.predictions))
predictions = sess.run(spec.predictions)
self.assertAllClose(
logits.astype(np.float32),
predictions[prediction_keys.PredictionKeys.LOGITS])
self.assertAllClose(
_sigmoid(logits).astype(np.float32),
predictions[prediction_keys.PredictionKeys.LOGISTIC])
self.assertAllClose(
[[0., 1.], [1., 0.],
[0., 1.]], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
self.assertAllClose([[1], [0], [1]],
predictions[prediction_keys.PredictionKeys.CLASS_IDS])
self.assertAllEqual([[b'1'], [b'0'], [b'1']],
predictions[prediction_keys.PredictionKeys.CLASSES])
def test_weighted_multi_example_eval(self):
"""3 examples, 1 batch."""