Add label-vocab support to binary logistic head.

Add an assertion that binary classifier labels are in the range [0., 1.].
Fix classifier integration tests.

PiperOrigin-RevId: 158307521
Author: Mustafa Ispir (2017-06-07 13:15:04 -07:00), committed by TensorFlower Gardener
Parent: f105df0478
Commit: edb5fed7fc
4 changed files with 220 additions and 150 deletions
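
As background for the diffs below: with a label_vocabulary, the binary head accepts
string labels and converts them to integer ids before computing the loss, and the ids
are then checked against the valid range. A minimal plain-Python sketch of that
conversion (the vocabulary ['aang', 'iroh'] comes from the new tests; labels_to_ids is
an illustrative stand-in, not a TensorFlow API):

    import numpy as np

    def labels_to_ids(labels, vocabulary):
      """Illustrative stand-in for the index-table lookup the head builds."""
      index = {name: i for i, name in enumerate(vocabulary)}
      return np.array([[index[label]] for (label,) in labels], dtype=np.int64)

    vocabulary = ['aang', 'iroh']
    labels = [('iroh',), ('iroh',)]     # string labels, shape (batch_size, 1)

    label_ids = labels_to_ids(labels, vocabulary)      # [[1], [1]]
    # The new range check: every id must satisfy 0 <= id < n_classes (here 2).
    assert ((label_ids >= 0) & (label_ids < 2)).all()
    print(label_ids.tolist())                          # [[1], [1]]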


@@ -322,6 +322,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
     if self._model_dir:
       shutil.rmtree(self._model_dir)
 
+  def _as_label(self, data_in_float):
+    return np.rint(data_in_float).astype(np.int64)
+
   def _test_complete_flow(
       self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
       n_classes, batch_size):
@@ -363,12 +366,13 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
   def test_numpy_input_fn(self):
     """Tests complete flow with numpy_input_fn."""
-    n_classes = 2
+    n_classes = 3
     input_dimension = 2
     batch_size = 10
-    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+    data = np.linspace(
+        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
     x_data = data.reshape(batch_size, input_dimension)
-    y_data = np.reshape(data[:batch_size], (batch_size, 1))
+    y_data = self._as_label(np.reshape(data[:batch_size], (batch_size, 1)))
     # learn y = x
     train_input_fn = numpy_io.numpy_input_fn(
         x={'x': x_data},
@@ -401,9 +405,9 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
     input_dimension = 1
     n_classes = 2
     batch_size = 10
-    data = np.linspace(0., 2., batch_size, dtype=np.float32)
+    data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
     x = pd.DataFrame({'x': data})
-    y = pd.Series(data)
+    y = pd.Series(self._as_label(data))
     train_input_fn = pandas_io.pandas_input_fn(
         x=x,
         y=y,
@@ -431,25 +435,28 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
   def test_input_fn_from_parse_example(self):
     """Tests complete flow with input_fn constructed from parse_example."""
     input_dimension = 2
-    n_classes = 2
+    n_classes = 3
     batch_size = 10
-    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+    data = np.linspace(0., n_classes-1., batch_size * input_dimension,
+                       dtype=np.float32)
     data = data.reshape(batch_size, input_dimension)
     serialized_examples = []
     for datum in data:
       example = example_pb2.Example(features=feature_pb2.Features(
           feature={
-              'x': feature_pb2.Feature(
-                  float_list=feature_pb2.FloatList(value=datum)),
-              'y': feature_pb2.Feature(
-                  float_list=feature_pb2.FloatList(value=datum[:1])),
+              'x':
+                  feature_pb2.Feature(float_list=feature_pb2.FloatList(
+                      value=datum)),
+              'y':
+                  feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+                      value=self._as_label(datum[:1]))),
           }))
       serialized_examples.append(example.SerializeToString())
     feature_spec = {
         'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
-        'y': parsing_ops.FixedLenFeature([1], dtypes.float32),
+        'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
     }
 
     def _train_input_fn():
       feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)

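
The _as_label helper added in the test diff above simply rounds the synthetic float
targets to integer class ids. A quick NumPy check of what it produces (the 5-element
linspace here is illustrative, not the exact data used by the tests):

    import numpy as np

    def _as_label(data_in_float):
      # Same body as the test helper above.
      return np.rint(data_in_float).astype(np.int64)

    data = np.linspace(0., 2., 5, dtype=np.float32)   # [0., 0.5, 1., 1.5, 2.]
    print(_as_label(data))                            # [0 0 1 2 2]
    # np.rint rounds halves to the nearest even integer (0.5 -> 0, 1.5 -> 2).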

@@ -300,12 +300,18 @@ class DNNClassifierPredictTest(test.TestCase):
     # logistic = exp(-2.08)/(1 + exp(-2.08)) = 0.11105597
     # probabilities = [1-logistic, logistic] = [0.88894403, 0.11105597]
     # class_ids = argmax(probabilities) = [0]
-    self.assertAllClose({
-        prediction_keys.PredictionKeys.LOGITS: [-2.08],
-        prediction_keys.PredictionKeys.LOGISTIC: [0.11105597],
-        prediction_keys.PredictionKeys.PROBABILITIES: [0.88894403, 0.11105597],
-        prediction_keys.PredictionKeys.CLASS_IDS: [0],
-    }, next(dnn_classifier.predict(input_fn=input_fn)))
+    predictions = next(dnn_classifier.predict(input_fn=input_fn))
+    self.assertAllClose([-2.08],
+                        predictions[prediction_keys.PredictionKeys.LOGITS])
+    self.assertAllClose([0.11105597],
+                        predictions[prediction_keys.PredictionKeys.LOGISTIC])
+    self.assertAllClose(
+        [0.88894403,
+         0.11105597], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+    self.assertAllClose([0],
+                        predictions[prediction_keys.PredictionKeys.CLASS_IDS])
+    self.assertAllEqual([b'0'],
+                        predictions[prediction_keys.PredictionKeys.CLASSES])
 
   def test_multi_dim(self):
     """Asserts predictions for multi-dimensional input and logits."""
@@ -535,6 +541,9 @@ class DNNClassifierIntegrationTest(test.TestCase):
     if self._model_dir:
       shutil.rmtree(self._model_dir)
 
+  def _as_label(self, data_in_float):
+    return np.rint(data_in_float).astype(np.int64)
+
   def _test_complete_flow(
       self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
       n_classes, batch_size):
@@ -572,12 +581,13 @@ class DNNClassifierIntegrationTest(test.TestCase):
   def test_numpy_input_fn(self):
     """Tests complete flow with numpy_input_fn."""
-    n_classes = 2
+    n_classes = 3
     input_dimension = 2
     batch_size = 10
-    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+    data = np.linspace(
+        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
     x_data = data.reshape(batch_size, input_dimension)
-    y_data = np.reshape(data[:batch_size], (batch_size, 1))
+    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
     # learn y = x
     train_input_fn = numpy_io.numpy_input_fn(
         x={'x': x_data},
@@ -608,11 +618,11 @@ class DNNClassifierIntegrationTest(test.TestCase):
     if not HAS_PANDAS:
       return
     input_dimension = 1
-    n_classes = 2
+    n_classes = 3
     batch_size = 10
-    data = np.linspace(0., 2., batch_size, dtype=np.float32)
+    data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
     x = pd.DataFrame({'x': data})
-    y = pd.Series(data)
+    y = pd.Series(self._as_label(data))
     train_input_fn = pandas_io.pandas_input_fn(
         x=x,
         y=y,
@@ -640,25 +650,28 @@ class DNNClassifierIntegrationTest(test.TestCase):
   def test_input_fn_from_parse_example(self):
     """Tests complete flow with input_fn constructed from parse_example."""
     input_dimension = 2
-    n_classes = 2
+    n_classes = 3
     batch_size = 10
-    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+    data = np.linspace(
+        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
     data = data.reshape(batch_size, input_dimension)
     serialized_examples = []
     for datum in data:
       example = example_pb2.Example(features=feature_pb2.Features(
           feature={
-              'x': feature_pb2.Feature(
-                  float_list=feature_pb2.FloatList(value=datum)),
-              'y': feature_pb2.Feature(
-                  float_list=feature_pb2.FloatList(value=datum[:1])),
+              'x':
+                  feature_pb2.Feature(float_list=feature_pb2.FloatList(
+                      value=datum)),
+              'y':
+                  feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+                      value=self._as_label(datum[:1]))),
          }))
       serialized_examples.append(example.SerializeToString())
     feature_spec = {
         'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
-        'y': parsing_ops.FixedLenFeature([1], dtypes.float32),
+        'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
     }
 
     def _train_input_fn():
       feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)

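
The expected values asserted in the prediction test above follow directly from the
sigmoid of the single logit, as the test's own comments note. A short sketch of that
arithmetic for the logit -2.08:

    import numpy as np

    logit = -2.08
    logistic = np.exp(logit) / (1. + np.exp(logit))   # sigmoid(-2.08) ~= 0.11105597
    probabilities = [1. - logistic, logistic]         # ~= [0.88894403, 0.11105597]
    class_id = int(np.argmax(probabilities))          # 0
    print(logistic, probabilities, class_id)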

@@ -302,7 +302,8 @@ def _multi_class_head_with_softmax_cross_entropy_loss(n_classes,
   Raises:
     ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid.
   """
-  if label_vocabulary is not None and not isinstance(label_vocabulary, list):
+  if label_vocabulary is not None and not isinstance(label_vocabulary,
+                                                     (list, tuple)):
     raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
         type(label_vocabulary)))
@@ -356,14 +357,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
       label_ids = lookup_ops.index_table_from_tensor(
           vocabulary_list=tuple(self._label_vocabulary),
           name='class_id_lookup').lookup(labels)
-    assert_less = check_ops.assert_less(
-        label_ids,
-        ops.convert_to_tensor(self._n_classes, dtype=label_ids.dtype),
-        message='Label IDs must < n_classes')
-    assert_greater = check_ops.assert_non_negative(
-        label_ids, message='Label Ids must >= 0')
-    with ops.control_dependencies((assert_less, assert_greater)):
-      return array_ops.identity(label_ids)
+    return _assert_range(label_ids, self._n_classes)
 
   def create_estimator_spec(
       self, features, mode, logits, labels=None, train_op_fn=None):
@@ -459,7 +453,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
 
 def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
-    weight_feature_key=None, thresholds=None):
+    weight_feature_key=None, thresholds=None, label_vocabulary=None):
   """Creates a `Head` for single label binary classification.
 
   This head uses `sigmoid_cross_entropy_with_logits` loss.
@@ -475,6 +469,11 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
       generated for each threshold value. This threshold is applied to the
       logistic values to determine the binary classification (i.e., above the
       threshold is `true`, below is `false`.
+    label_vocabulary: A list of strings represents possible label values. If it
+      is not given, that means labels are already encoded within [0, 1]. If
+      given, labels must be string type and have any value in
+      `label_vocabulary`. Also there will be errors if vocabulary is not
+      provided and labels are string.
 
   Returns:
     An instance of `Head` for binary classification.
@@ -483,49 +482,80 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
     ValueError: if `thresholds` contains a value outside of `(0, 1)`.
   """
   thresholds = tuple(thresholds) if thresholds else tuple()
+  if label_vocabulary is not None and not isinstance(label_vocabulary,
+                                                     (list, tuple)):
+    raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
+        type(label_vocabulary)))
   for threshold in thresholds:
     if (threshold <= 0.0) or (threshold >= 1.0):
       raise ValueError('thresholds not in (0, 1): %s.' % (thresholds,))
   return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
-      weight_feature_key=weight_feature_key, thresholds=thresholds)
+      weight_feature_key=weight_feature_key,
+      thresholds=thresholds,
+      label_vocabulary=label_vocabulary)
 
 
 class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
   """See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`."""
 
-  def __init__(self, weight_feature_key=None, thresholds=None):
+  def __init__(self,
+               weight_feature_key=None,
+               thresholds=None,
+               label_vocabulary=None):
     self._weight_feature_key = weight_feature_key
     self._thresholds = thresholds
+    self._label_vocabulary = label_vocabulary
 
   @property
   def logits_dimension(self):
     return 1
 
-  def _eval_metric_ops(
-      self, labels, logits, logistic, scores, classes, unweighted_loss,
-      weights=None):
-    with ops.name_scope(
-        None, 'metrics',
-        (labels, logits, logistic, scores, classes, unweighted_loss, weights)):
+  def _eval_metric_ops(self,
+                       labels,
+                       logits,
+                       logistic,
+                       scores,
+                       class_ids,
+                       unweighted_loss,
+                       weights=None):
+    with ops.name_scope(None, 'metrics', (labels, logits, logistic, scores,
+                                          class_ids, unweighted_loss, weights)):
       keys = metric_keys.MetricKeys
       labels_mean = _indicator_labels_mean(
           labels=labels, weights=weights, name=keys.LABEL_MEAN)
       metric_ops = {
           # Estimator already adds a metric for loss.
-          keys.LOSS_MEAN: metrics_lib.mean(
-              unweighted_loss, weights=weights, name=keys.LOSS_MEAN),
-          keys.ACCURACY: metrics_lib.accuracy(
-              labels=labels, predictions=classes, weights=weights,
-              name=keys.ACCURACY),
-          keys.PREDICTION_MEAN: _predictions_mean(
-              predictions=logistic, weights=weights, name=keys.PREDICTION_MEAN),
-          keys.LABEL_MEAN: labels_mean,
-          keys.ACCURACY_BASELINE: _accuracy_baseline(labels_mean),
-          keys.AUC: _auc(
-              labels=labels, predictions=logistic, weights=weights,
-              name=keys.AUC),
-          keys.AUC_PR: _auc(
-              labels=labels, predictions=logistic, weights=weights, curve='PR',
-              name=keys.AUC_PR)
+          keys.LOSS_MEAN:
+              metrics_lib.mean(
+                  unweighted_loss, weights=weights, name=keys.LOSS_MEAN),
+          keys.ACCURACY:
+              metrics_lib.accuracy(
+                  labels=labels,
+                  predictions=class_ids,
+                  weights=weights,
+                  name=keys.ACCURACY),
+          keys.PREDICTION_MEAN:
+              _predictions_mean(
+                  predictions=logistic,
+                  weights=weights,
+                  name=keys.PREDICTION_MEAN),
+          keys.LABEL_MEAN:
+              labels_mean,
+          keys.ACCURACY_BASELINE:
+              _accuracy_baseline(labels_mean),
+          keys.AUC:
+              _auc(
+                  labels=labels,
+                  predictions=logistic,
+                  weights=weights,
+                  name=keys.AUC),
+          keys.AUC_PR:
+              _auc(
+                  labels=labels,
+                  predictions=logistic,
+                  weights=weights,
+                  curve='PR',
+                  name=keys.AUC_PR)
       }
       for threshold in self._thresholds:
@@ -559,27 +589,39 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
       two_class_logits = array_ops.concat(
           (array_ops.zeros_like(logits), logits), 1, name='two_class_logits')
       scores = nn.softmax(two_class_logits, name=pred_keys.PROBABILITIES)
-      classes = array_ops.reshape(
+      class_ids = array_ops.reshape(
           math_ops.argmax(two_class_logits, axis=1), (-1, 1), name='classes')
+      if self._label_vocabulary:
+        table = lookup_ops.index_to_string_table_from_tensor(
+            vocabulary_list=self._label_vocabulary, name='class_string_lookup')
+        classes = table.lookup(class_ids)
+      else:
+        classes = string_ops.as_string(class_ids, name='str_classes')
       predictions = {
           pred_keys.LOGITS: logits,
           pred_keys.LOGISTIC: logistic,
          pred_keys.PROBABILITIES: scores,
-          pred_keys.CLASS_IDS: classes
+          pred_keys.CLASS_IDS: class_ids,
+          pred_keys.CLASSES: classes,
       }
       if mode == model_fn.ModeKeys.PREDICT:
         return model_fn.EstimatorSpec(
             mode=model_fn.ModeKeys.PREDICT,
             predictions=predictions,
-            export_outputs={'': export_output.ClassificationOutput(
-                scores=scores,
-                # `ClassificationOutput` requires string classes.
-                # TODO(ptucker): Support label_keys.
-                classes=string_ops.as_string(classes, name='str_classes'))})
+            export_outputs={
+                '':
+                    export_output.ClassificationOutput(
+                        scores=scores, classes=classes)
+            })
 
       # Eval.
-      labels = _check_labels(_maybe_expand_dim(math_ops.to_float(labels)),
-                             self.logits_dimension)
+      labels = _check_labels(_maybe_expand_dim(labels), self.logits_dimension)
+      if self._label_vocabulary is not None:
+        labels = lookup_ops.index_table_from_tensor(
+            vocabulary_list=tuple(self._label_vocabulary),
+            name='class_id_lookup').lookup(labels)
+      labels = math_ops.to_float(labels)
+      labels = _assert_range(labels, 2)
       unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
           labels=labels, logits=logits, name='loss')
       weights = (
@@ -598,7 +640,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
              logits=logits,
              logistic=logistic,
              scores=scores,
-              classes=classes,
+              class_ids=class_ids,
              unweighted_loss=unweighted_loss,
              weights=weights))
@@ -721,3 +763,14 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
         predictions=predictions,
         loss=training_loss,
         train_op=train_op_fn(training_loss))
+
+
+def _assert_range(labels, n_classes):
+  assert_less = check_ops.assert_less(
+      labels,
+      ops.convert_to_tensor(n_classes, dtype=labels.dtype),
+      message='Label IDs must < n_classes')
+  assert_greater = check_ops.assert_non_negative(
+      labels, message='Label IDs must >= 0')
+  with ops.control_dependencies((assert_less, assert_greater)):
+    return array_ops.identity(labels)

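
To make the prediction-side changes above easier to follow: the binary head expands its
single logit column into two-class logits, takes a softmax for PROBABILITIES, an argmax
for CLASS_IDS, and maps ids to strings for CLASSES (through the vocabulary when one is
given). A minimal NumPy sketch that mirrors those ops rather than calling TensorFlow;
the vocabulary and logits match the new test_predict_with_vocabulary_list in the test
diff below:

    import numpy as np

    def softmax(x):
      e = np.exp(x - x.max(axis=1, keepdims=True))
      return e / e.sum(axis=1, keepdims=True)

    logits = np.array([[1.], [0.]])          # shape (batch_size, 1)
    label_vocabulary = ['aang', 'iroh']      # optional; None means ids 0/1

    # Mirrors concat((zeros_like(logits), logits), 1) -> softmax -> argmax.
    two_class_logits = np.concatenate([np.zeros_like(logits), logits], axis=1)
    probabilities = softmax(two_class_logits)       # ~= [[0.269, 0.731], [0.5, 0.5]]
    class_ids = np.argmax(two_class_logits, axis=1).reshape(-1, 1)

    if label_vocabulary:
      classes = np.array([[label_vocabulary[i].encode()] for (i,) in class_ids])
    else:
      classes = class_ids.astype(bytes)            # b'0' / b'1' string ids

    print(class_ids.tolist())    # [[1], [0]]
    print(classes.tolist())      # [[b'iroh'], [b'aang']]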

@@ -206,7 +206,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
         })
     with self.test_session():
-      with self.assertRaisesOpError('Label Ids must >= 0'):
+      with self.assertRaisesOpError('Label IDs must >= 0'):
        spec.loss.eval({
            labels_placeholder: labels_2x1_with_negative_id,
            logits_placeholder: logits_2x3
@@ -743,8 +743,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
     self.assertEqual(1, head.logits_dimension)
 
     # Both logits and labels should be shape (batch_size, 1).
-    values_2x1 = np.array(((43.,), (44.,),))
-    values_3x1 = np.array(((45.,), (46.,), (47.,),))
+    values_2x1 = np.array(((0.,), (1.,),))
+    values_3x1 = np.array(((0.,), (1.,), (0.,),))
 
     # Static shape.
     with self.assertRaisesRegexp(
@@ -788,28 +788,13 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
     self.assertEqual(1, head.logits_dimension)
 
     # Create estimator spec.
-    logits = np.array(((45,), (-41,),), dtype=np.int32)
+    logits = [[45.], [-41.]]
     spec = head.create_estimator_spec(
         features={'x': np.array(((42,),), dtype=np.int32)},
         mode=model_fn.ModeKeys.PREDICT,
         logits=logits)
 
-    expected_predictions = {
-        prediction_keys.PredictionKeys.LOGITS:
-            logits.astype(np.float32),
-        prediction_keys.PredictionKeys.LOGISTIC:
-            _sigmoid(logits).astype(np.float32),
-        prediction_keys.PredictionKeys.PROBABILITIES:
-            np.array(((0., 1.), (1., 0.),), dtype=np.float32),
-        prediction_keys.PredictionKeys.CLASS_IDS:
-            np.array(((1,), (0,)), dtype=np.int64),
-    }
-
     # Assert spec contains expected tensors.
-    self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
-    self.assertEqual(
-        {k: v.dtype for k, v in six.iteritems(expected_predictions)},
-        {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
     self.assertIsNone(spec.loss)
     self.assertEqual({}, spec.eval_metric_ops)
     self.assertIsNone(spec.train_op)
@@ -821,7 +806,37 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
     with self.test_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNone(spec.scaffold.summary_op)
-      self.assertAllClose(expected_predictions, sess.run(spec.predictions))
+      predictions = sess.run(spec.predictions)
+      self.assertAllClose(logits,
+                          predictions[prediction_keys.PredictionKeys.LOGITS])
+      self.assertAllClose(
+          _sigmoid(np.array(logits)),
+          predictions[prediction_keys.PredictionKeys.LOGISTIC])
+      self.assertAllClose(
+          [[0., 1.],
+           [1., 0.]], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+      self.assertAllClose([[1], [0]],
+                          predictions[prediction_keys.PredictionKeys.CLASS_IDS])
+      self.assertAllEqual([[b'1'], [b'0']],
+                          predictions[prediction_keys.PredictionKeys.CLASSES])
+
+  def test_predict_with_vocabulary_list(self):
+    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+        label_vocabulary=['aang', 'iroh'])
+
+    logits = [[1.], [0.]]
+    expected_classes = [[b'iroh'], [b'aang']]
+
+    spec = head.create_estimator_spec(
+        features={'x': np.array(((42,),), dtype=np.int32)},
+        mode=model_fn.ModeKeys.PREDICT,
+        logits=logits)
+
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      self.assertAllEqual(
+          expected_classes,
+          sess.run(spec.predictions[prediction_keys.PredictionKeys.CLASSES]))
 
   def test_eval(self):
     head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
@@ -834,17 +849,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
         logits=logits,
         labels=np.array(((1,), (1,),), dtype=np.int32))
 
-    expected_predictions = {
-        prediction_keys.PredictionKeys.LOGITS:
-            logits.astype(np.float32),
-        prediction_keys.PredictionKeys.LOGISTIC:
-            _sigmoid(logits).astype(np.float32),
-        prediction_keys.PredictionKeys.PROBABILITIES:
-            np.array(((0., 1.), (1., 0.),), dtype=np.float32),
-        # TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)?
-        prediction_keys.PredictionKeys.CLASS_IDS:
-            np.array(((1,), (0,)), dtype=np.int64),
-    }
     keys = metric_keys.MetricKeys
     expected_metrics = {
         # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41
@@ -859,10 +863,6 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
     }
 
     # Assert spec contains expected tensors.
-    self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
-    self.assertEqual(
-        {k: v.dtype for k, v in six.iteritems(expected_predictions)},
-        {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
     self.assertIsNotNone(spec.loss)
     self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
     self.assertIsNone(spec.train_op)
@@ -875,15 +875,34 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
       self.assertIsNone(spec.scaffold.summary_op)
       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
       update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
-      predictions, loss, metrics = sess.run((
-          spec.predictions, spec.loss, update_ops))
-      self.assertAllClose(expected_predictions, predictions)
+      loss, metrics = sess.run((spec.loss, update_ops))
       self.assertAllClose(41., loss)
       # Check results of both update (in `metrics`) and value ops.
       self.assertAllClose(expected_metrics, metrics)
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops})
 
+  def test_eval_with_vocabulary_list(self):
+    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+        label_vocabulary=['aang', 'iroh'])
+
+    # Create estimator spec.
+    logits = np.array(((45,), (-41,),), dtype=np.float32)
+    spec = head.create_estimator_spec(
+        features={'x': np.array(((42,),), dtype=np.float32)},
+        mode=model_fn.ModeKeys.EVAL,
+        logits=logits,
+        labels=[[b'iroh'], [b'iroh']])
+
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      self.assertIsNone(spec.scaffold.summary_op)
+      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
+      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
+      sess.run(update_ops)
+      self.assertAllClose(1. / 2,
+                          value_ops[metric_keys.MetricKeys.ACCURACY].eval())
+
   def test_eval_with_thresholds(self):
     thresholds = [0.25, 0.5, 0.75]
     head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
@@ -942,23 +961,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
         labels=np.array(((1,), (1,),), dtype=np.float64),
         train_op_fn=_train_op_fn)
 
-    expected_predictions = {
-        prediction_keys.PredictionKeys.LOGITS:
-            logits.astype(np.float32),
-        prediction_keys.PredictionKeys.LOGISTIC:
-            _sigmoid(logits).astype(np.float32),
-        prediction_keys.PredictionKeys.PROBABILITIES:
-            np.array(((0., 1.), (1., 0.),), dtype=np.float32),
-        # TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)?
-        prediction_keys.PredictionKeys.CLASS_IDS:
-            np.array(((1,), (0,)), dtype=np.int64),
-    }
-
     # Assert spec contains expected tensors.
-    self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
-    self.assertEqual(
-        {k: v.dtype for k, v in six.iteritems(expected_predictions)},
-        {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
     self.assertIsNotNone(spec.loss)
     self.assertEqual({}, spec.eval_metric_ops)
     self.assertIsNotNone(spec.train_op)
@@ -969,9 +972,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
     with self.test_session() as sess:
       _initialize_variables(self, spec.scaffold)
       self.assertIsNotNone(spec.scaffold.summary_op)
-      predictions, loss, train_result, summary_str = sess.run((
-          spec.predictions, spec.loss, spec.train_op, spec.scaffold.summary_op))
-      self.assertAllClose(expected_predictions, predictions)
+      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
+                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss)
      self.assertEqual(expected_train_result, train_result)
      _assert_simple_summaries(self, {
@@ -995,28 +997,23 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
         mode=model_fn.ModeKeys.PREDICT,
         logits=logits)
 
-    expected_predictions = {
-        prediction_keys.PredictionKeys.LOGITS:
-            logits.astype(np.float32),
-        prediction_keys.PredictionKeys.LOGISTIC:
-            _sigmoid(logits).astype(np.float32),
-        prediction_keys.PredictionKeys.PROBABILITIES:
-            np.array(((0., 1.), (1., 0.), (0., 1.)), dtype=np.float32),
-        # TODO(ptucker): Should this be (batch_size, 1) instead of (batch_size)?
-        prediction_keys.PredictionKeys.CLASS_IDS:
-            np.array(((1,), (0,), (1,)), dtype=np.int64),
-    }
-
-    # Assert spec contains expected tensors.
-    self.assertItemsEqual(expected_predictions.keys(), spec.predictions.keys())
-    self.assertEqual(
-        {k: v.dtype for k, v in six.iteritems(expected_predictions)},
-        {k: v.dtype.as_numpy_dtype for k, v in six.iteritems(spec.predictions)})
-
     # Assert predictions, loss, and metrics.
     with self.test_session() as sess:
       _initialize_variables(self, spec.scaffold)
-      self.assertAllClose(expected_predictions, sess.run(spec.predictions))
+      predictions = sess.run(spec.predictions)
+      self.assertAllClose(
+          logits.astype(np.float32),
+          predictions[prediction_keys.PredictionKeys.LOGITS])
+      self.assertAllClose(
+          _sigmoid(logits).astype(np.float32),
+          predictions[prediction_keys.PredictionKeys.LOGISTIC])
+      self.assertAllClose(
+          [[0., 1.], [1., 0.],
+           [0., 1.]], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+      self.assertAllClose([[1], [0], [1]],
+                          predictions[prediction_keys.PredictionKeys.CLASS_IDS])
+      self.assertAllEqual([[b'1'], [b'0'], [b'1']],
+                          predictions[prediction_keys.PredictionKeys.CLASSES])
 
   def test_weighted_multi_example_eval(self):
     """3 examples, 1 batch."""