From e9bd04ad1f78aa180e2af349485e65c3a369963a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Sep 2017 10:26:06 -0700 Subject: [PATCH] Supports label_vocabulary in multi_label_head. PiperOrigin-RevId: 169260799 --- tensorflow/contrib/estimator/BUILD | 10 + .../estimator/python/estimator/head.py | 38 ++- .../estimator/python/estimator/head_test.py | 255 +++++++++++------- 3 files changed, 206 insertions(+), 97 deletions(-) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 4b2050c932a..b27dd05f322 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -103,9 +103,15 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", "//tensorflow/python:metrics", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:summary", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:head", @@ -124,9 +130,13 @@ py_test( deps = [ ":head", "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", "//tensorflow/python/estimator:metric_keys", diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 164dfe6e82a..aee9409b1fb 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -23,10 +23,12 @@ from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import sparse_ops @@ -138,10 +140,10 @@ def regression_head(weight_column=None, head_name=head_name) -# TODO(roumposg): Support label_vocabulary. def multi_label_head(n_classes, weight_column=None, thresholds=None, + label_vocabulary=None, head_name=None): """Creates a `_Head` for multi-label classification. @@ -164,6 +166,11 @@ def multi_label_head(n_classes, and recall metrics are evaluated for each threshold value. The threshold is applied to the predicted probabilities, i.e. above the threshold is `true`, below is `false`. + label_vocabulary: A list of strings representing possible label values. If + it is not given, labels must already be encoded as integers within + [0, n_classes) or as a multi-hot Tensor. If given, labels must be a + string-type SparseTensor whose values are contained in + `label_vocabulary`. An error is raised if labels are strings but no + vocabulary is provided. head_name: name of the head. 
If provided, summary and metrics keys will be suffixed by `"/" + head_name`. @@ -182,9 +189,18 @@ def multi_label_head(n_classes, if (threshold <= 0.0) or (threshold >= 1.0): raise ValueError( 'thresholds must be in (0, 1) range. Given: {}'.format(threshold)) + if label_vocabulary is not None: + if not isinstance(label_vocabulary, (list, tuple)): + raise ValueError( + 'label_vocabulary must be a list or tuple. ' + 'Given type: {}'.format(type(label_vocabulary))) + if len(label_vocabulary) != n_classes: + raise ValueError( + 'Length of label_vocabulary must be n_classes ({}). ' + 'Given: {}'.format(n_classes, len(label_vocabulary))) return _MultiLabelHead( n_classes=n_classes, weight_column=weight_column, thresholds=thresholds, - head_name=head_name) + label_vocabulary=label_vocabulary, head_name=head_name) class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access @@ -194,10 +210,12 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access n_classes, weight_column=None, thresholds=None, + label_vocabulary=None, head_name=None): self._n_classes = n_classes self._weight_column = weight_column self._thresholds = thresholds + self._label_vocabulary = label_vocabulary self._head_name = head_name @property @@ -206,8 +224,18 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access def _process_labels(self, labels): if isinstance(labels, sparse_tensor.SparseTensor): + if labels.dtype == dtypes.string: + label_ids_values = lookup_ops.index_table_from_tensor( + vocabulary_list=tuple(self._label_vocabulary), + name='class_id_lookup').lookup(labels.values) + label_ids = sparse_tensor.SparseTensor( + indices=labels.indices, + values=label_ids_values, + dense_shape=labels.dense_shape) + else: + label_ids = labels return math_ops.to_int64( - sparse_ops.sparse_to_indicator(labels, self._n_classes)) + sparse_ops.sparse_to_indicator(label_ids, self._n_classes)) msg = ('labels shape must be [batch_size, {}]. ' 'Given: ').format(self._n_classes) labels_shape = array_ops.shape(labels) @@ -254,7 +282,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access }) # Eval. - unweighted_loss, _ = self.create_loss( + unweighted_loss, processed_labels = self.create_loss( features=features, mode=mode, logits=logits, labels=labels) # Averages loss over classes. per_example_loss = math_ops.reduce_mean( @@ -268,7 +296,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access predictions=predictions, loss=training_loss, eval_metric_ops=self._eval_metric_ops( - labels=labels, + labels=processed_labels, probabilities=probabilities, weights=weights, per_example_loss=per_example_loss)) diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 17753b4c9b0..009666d80a9 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -113,6 +113,19 @@ class MultiLabelHead(test.TestCase): r'thresholds must be in \(0, 1\) range\. Given: 1\.0'): head_lib.multi_label_head(n_classes=2, thresholds=[0.5, 1.0]) + def test_label_vocabulary_dict(self): + with self.assertRaisesRegexp( + ValueError, + r'label_vocabulary must be a list or tuple\. 
' + r'Given type: <(type|class) \'dict\'>'): + head_lib.multi_label_head(n_classes=2, label_vocabulary={'foo': 'bar'}) + + def test_label_vocabulary_wrong_size(self): + with self.assertRaisesRegexp( + ValueError, + r'Length of label_vocabulary must be n_classes \(3\). Given: 2'): + head_lib.multi_label_head(n_classes=3, label_vocabulary=['foo', 'bar']) + def test_predict(self): n_classes = 4 head = head_lib.multi_label_head(n_classes) @@ -219,35 +232,6 @@ class MultiLabelHead(test.TestCase): self.assertAllClose( expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4) - def test_eval_create_loss_sparse_labels(self): - """Tests head.create_loss for eval mode and sparse labels.""" - n_classes = 2 - head = head_lib.multi_label_head(n_classes) - - logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) - labels = sparse_tensor.SparseTensor( - values=[0, 0, 1], - indices=[[0, 0], [1, 0], [1, 1]], - dense_shape=[2, 2]) - expected_labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - # loss = labels * -log(sigmoid(logits)) + - # (1 - labels) * -log(1 - sigmoid(logits)) - # For large logits, this is approximated as: - # loss = labels * (logits < 0) * (-logits) + - # (1 - labels) * (logits > 0) * logits - expected_unweighted_loss = np.array( - [[10., 10.], [15., 0.]], dtype=np.float32) - actual_unweighted_loss, actual_labels = head.create_loss( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.EVAL, - logits=logits, - labels=labels) - with self.test_session(): - _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllEqual(expected_labels, actual_labels.eval()) - self.assertAllClose( - expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4) - def test_eval_create_loss_labels_wrong_shape(self): """Tests head.create_loss for eval mode when labels has the wrong shape.""" n_classes = 2 @@ -273,35 +257,13 @@ class MultiLabelHead(test.TestCase): actual_unweighted_loss.eval( {labels_placeholder: np.array([1, 1], dtype=np.int64)}) - def test_eval(self): - n_classes = 2 - head = head_lib.multi_label_head(n_classes) - - logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) - labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - # loss = labels * -log(sigmoid(logits)) + - # (1 - labels) * -log(1 - sigmoid(logits)) - # Average over classes, and sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes - ) - + def _test_eval(self, head, logits, labels, expected_loss, expected_metrics): spec = head.create_estimator_spec( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels) - keys = metric_keys.MetricKeys - expected_metrics = { - # Average loss over examples. - keys.LOSS_MEAN: expected_loss / 2, - # auc and auc_pr cannot be reliably calculated for only 4 samples, but - # this assert tests that the algorithm remains consistent. - keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, - } - # Assert spec contains expected tensors. 
self.assertIsNotNone(spec.loss) self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys()) @@ -325,6 +287,100 @@ class MultiLabelHead(test.TestCase): rtol=tol, atol=tol) + def test_eval(self): + n_classes = 2 + head = head_lib.multi_label_head(n_classes) + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Average over classes, and sum over examples. + expected_loss = ( + np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes + ) + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss / 2, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + } + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + + def test_eval_sparse_labels(self): + n_classes = 2 + head = head_lib.multi_label_head(n_classes) + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + # Equivalent to multi_hot = [[1, 0], [1, 1]] + labels = sparse_tensor.SparseTensor( + values=[0, 0, 1], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Average over classes, and sum over examples. + expected_loss = ( + np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) / + n_classes + ) + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss / 2, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + } + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + + def test_eval_with_label_vocabulary(self): + n_classes = 2 + head = head_lib.multi_label_head( + n_classes, label_vocabulary=['class0', 'class1']) + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + # Equivalent to multi_hot = [[1, 0], [1, 1]] + labels = sparse_tensor.SparseTensor( + values=['class0', 'class0', 'class1'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Average over classes, and sum over examples. + expected_loss = ( + np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) / + n_classes + ) + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss / 2, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. 
+ keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + } + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + def test_eval_with_thresholds(self): n_classes = 2 thresholds = [0.25, 0.5, 0.75] @@ -339,12 +395,6 @@ class MultiLabelHead(test.TestCase): np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes ) - spec = head.create_estimator_spec( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.EVAL, - logits=logits, - labels=labels) - keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. @@ -364,28 +414,12 @@ class MultiLabelHead(test.TestCase): keys.RECALL_AT_THRESHOLD % thresholds[2]: 1. / 3., } - # Assert spec contains expected tensors. - self.assertIsNotNone(spec.loss) - self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys()) - self.assertIsNone(spec.train_op) - self.assertIsNone(spec.export_outputs) - _assert_no_hooks(self, spec) - - # Assert predictions, loss, and metrics. - tol = 1e-3 - with self.test_session() as sess: - _initialize_variables(self, spec.scaffold) - self.assertIsNone(spec.scaffold.summary_op) - value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} - update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} - loss, metrics = sess.run((spec.loss, update_ops)) - self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) - # Check results of both update (in `metrics`) and value ops. - self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol) - self.assertAllClose( - expected_metrics, {k: value_ops[k].eval() for k in value_ops}, - rtol=tol, - atol=tol) + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) def test_eval_with_weights(self): n_classes = 2 @@ -466,18 +500,7 @@ class MultiLabelHead(test.TestCase): self.assertAllClose( expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4) - def test_train(self): - n_classes = 2 - head = head_lib.multi_label_head(n_classes) - - logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) - labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - # For large logits, sigmoid cross entropy loss is approximated as: - # loss = labels * (logits < 0) * (-logits) + - # (1 - labels) * (logits > 0) * logits => - # expected_unweighted_loss = [[10., 10.], [15., 0.]] - # Average over classes, sum over weights. - expected_loss = 17.5 + def _test_train(self, head, logits, labels, expected_loss): expected_train_result = 'my_train_op' def _train_op_fn(loss): return string_ops.string_join( @@ -514,6 +537,54 @@ class MultiLabelHead(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2, }, summary_str, tol) + def test_train(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes, sum over weights. 
+ expected_loss = 17.5 + self._test_train( + head=head, logits=logits, labels=labels, expected_loss=expected_loss) + + def test_train_sparse_labels(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + # Equivalent to multi_hot = [[1, 0], [1, 1]] + labels = sparse_tensor.SparseTensor( + values=[0, 0, 1], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes, sum over weights. + expected_loss = 17.5 + self._test_train( + head=head, logits=logits, labels=labels, expected_loss=expected_loss) + + def test_train_with_label_vocabulary(self): + head = head_lib.multi_label_head( + n_classes=2, label_vocabulary=['class0', 'class1']) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + # Equivalent to multi_hot = [[1, 0], [1, 1]] + labels = sparse_tensor.SparseTensor( + values=['class0', 'class0', 'class1'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes, sum over weights. + expected_loss = 17.5 + self._test_train( + head=head, logits=logits, labels=labels, expected_loss=expected_loss) + def test_train_with_weights(self): n_classes = 2 head = head_lib.multi_label_head(n_classes, weight_column='label_weights')
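
Editor's note, not part of the patch: a minimal usage sketch of the new argument. It assumes a TF 1.x build that contains this revision, where the head above is exported as `tf.contrib.estimator.multi_label_head`, and relies on the two-value `create_loss` contract exercised by the tests, i.e. `(unweighted_loss, processed_labels)`. The fixtures mirror `test_eval_with_label_vocabulary`.

import numpy as np
import tensorflow as tf

head = tf.contrib.estimator.multi_label_head(
    n_classes=2, label_vocabulary=['class0', 'class1'])

logits = tf.constant([[-1., 1.], [-1.5, 1.5]])
# String labels as a SparseTensor; equivalent to multi_hot = [[1, 0], [1, 1]].
labels = tf.SparseTensor(
    values=['class0', 'class0', 'class1'],
    indices=[[0, 0], [1, 0], [1, 1]],
    dense_shape=[2, 2])

# create_loss looks the strings up in label_vocabulary and returns the
# unweighted sigmoid cross-entropy along with the processed multi-hot labels.
unweighted_loss, processed_labels = head.create_loss(
    features={'x': np.array([[42]], dtype=np.int32)},
    mode=tf.estimator.ModeKeys.EVAL,
    logits=logits,
    labels=labels)

with tf.Session() as sess:
  sess.run(tf.tables_initializer())  # initializes the vocabulary lookup table
  print(sess.run(processed_labels))  # => [[1 0]
                                     #     [1 1]]

Running this prints the multi-hot labels recovered from the string SparseTensor, which is exactly the tensor the eval metrics above now receive via `processed_labels`.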
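Likewise hypothetical: a standalone sketch of the conversion `_process_labels` performs on string labels, namely a vocabulary lookup that keeps the sparse structure, followed by `sparse_to_indicator` to produce the dense multi-hot matrix. It uses the same internal modules the patch imports (`lookup_ops`, `sparse_ops`, `math_ops`); all names below are local to the example.

import tensorflow as tf
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops

vocab = ('class0', 'class1')
# Each row of the SparseTensor lists the class strings present in that example.
labels = tf.SparseTensor(
    values=['class0', 'class0', 'class1'],
    indices=[[0, 0], [1, 0], [1, 1]],
    dense_shape=[2, 2])

# Step 1: map string values to integer ids, preserving the sparse structure.
label_ids_values = lookup_ops.index_table_from_tensor(
    vocabulary_list=vocab, name='class_id_lookup').lookup(labels.values)
label_ids = tf.SparseTensor(
    indices=labels.indices,
    values=label_ids_values,
    dense_shape=labels.dense_shape)

# Step 2: expand the sparse ids into a dense [batch_size, n_classes] multi-hot.
multi_hot = math_ops.to_int64(
    sparse_ops.sparse_to_indicator(label_ids, len(vocab)))

with tf.Session() as sess:
  sess.run(tf.tables_initializer())
  print(sess.run(multi_hot))  # => [[1 0]
                              #     [1 1]]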