Supports label_vocabulary in multi_label_head.

PiperOrigin-RevId: 169260799
A. Unique TensorFlower 2017-09-19 10:26:06 -07:00 committed by TensorFlower Gardener
parent 3588ff74d7
commit e9bd04ad1f
3 changed files with 206 additions and 97 deletions
tensorflow/contrib/estimator/BUILD
tensorflow/contrib/estimator/python/estimator/head.py
tensorflow/contrib/estimator/python/estimator/head_test.py

tensorflow/contrib/estimator/BUILD

@@ -103,9 +103,15 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:metrics",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:summary",
"//tensorflow/python/estimator:export_output",
"//tensorflow/python/estimator:head",
@@ -124,9 +130,13 @@ py_test(
deps = [
":head",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:training",
"//tensorflow/python/estimator:metric_keys",

tensorflow/contrib/estimator/python/estimator/head.py

@@ -23,10 +23,12 @@ from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import sparse_ops
@@ -138,10 +140,10 @@ def regression_head(weight_column=None,
head_name=head_name)
# TODO(roumposg): Support label_vocabulary.
def multi_label_head(n_classes,
weight_column=None,
thresholds=None,
label_vocabulary=None,
head_name=None):
"""Creates a `_Head` for multi-label classification.
@@ -164,6 +166,11 @@ def multi_label_head(n_classes,
and recall metrics are evaluated for each threshold value. The threshold
is applied to the predicted probabilities, i.e. above the threshold is
`true`, below is `false`.
label_vocabulary: A list of strings representing possible label values. If
it is not given, labels must already be encoded as integers within
[0, n_classes) or as a multi-hot Tensor. If given, labels must be a
string-type SparseTensor whose values are contained in `label_vocabulary`.
An error is raised if labels are strings and no vocabulary is provided.
head_name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + head_name`.
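As an illustration of the two label formats the docstring describes, here is a minimal usage sketch. It is not part of this change; the `tf.contrib.estimator` export name and the sample values are assumptions:

from tensorflow.contrib.estimator import multi_label_head
from tensorflow.python.framework import sparse_tensor

# Two classes identified by strings rather than integer ids.
head = multi_label_head(
    n_classes=2, label_vocabulary=['class0', 'class1'])

# With a vocabulary, labels arrive as a string SparseTensor; this one
# encodes the multi-hot matrix [[1, 0], [1, 1]].
labels = sparse_tensor.SparseTensor(
    values=['class0', 'class0', 'class1'],
    indices=[[0, 0], [1, 0], [1, 1]],
    dense_shape=[2, 2])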
@@ -182,9 +189,18 @@ def multi_label_head(n_classes,
if (threshold <= 0.0) or (threshold >= 1.0):
raise ValueError(
'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
if label_vocabulary is not None:
if not isinstance(label_vocabulary, (list, tuple)):
raise ValueError(
'label_vocabulary must be a list or tuple. '
'Given type: {}'.format(type(label_vocabulary)))
if len(label_vocabulary) != n_classes:
raise ValueError(
'Length of label_vocabulary must be n_classes ({}). '
'Given: {}'.format(n_classes, len(label_vocabulary)))
return _MultiLabelHead(
n_classes=n_classes, weight_column=weight_column, thresholds=thresholds,
head_name=head_name)
label_vocabulary=label_vocabulary, head_name=head_name)
class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
@@ -194,10 +210,12 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
n_classes,
weight_column=None,
thresholds=None,
label_vocabulary=None,
head_name=None):
self._n_classes = n_classes
self._weight_column = weight_column
self._thresholds = thresholds
self._label_vocabulary = label_vocabulary
self._head_name = head_name
@property
@@ -206,8 +224,18 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
def _process_labels(self, labels):
if isinstance(labels, sparse_tensor.SparseTensor):
if labels.dtype == dtypes.string:
label_ids_values = lookup_ops.index_table_from_tensor(
vocabulary_list=tuple(self._label_vocabulary),
name='class_id_lookup').lookup(labels.values)
label_ids = sparse_tensor.SparseTensor(
indices=labels.indices,
values=label_ids_values,
dense_shape=labels.dense_shape)
else:
label_ids = labels
return math_ops.to_int64(
sparse_ops.sparse_to_indicator(labels, self._n_classes))
sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
msg = ('labels shape must be [batch_size, {}]. '
'Given: ').format(self._n_classes)
labels_shape = array_ops.shape(labels)
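To make the lookup path above concrete, here is a small self-contained sketch of the same string-to-multi-hot conversion, assuming TF 1.x graph mode and a two-class vocabulary; the tensors are invented for illustration:

import tensorflow as tf
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops

labels = sparse_tensor.SparseTensor(
    values=['class0', 'class0', 'class1'],
    indices=[[0, 0], [1, 0], [1, 1]],
    dense_shape=[2, 2])
# Map string values to integer ids using the vocabulary.
label_ids_values = lookup_ops.index_table_from_tensor(
    vocabulary_list=('class0', 'class1'),
    name='class_id_lookup').lookup(labels.values)
label_ids = sparse_tensor.SparseTensor(
    indices=labels.indices,
    values=label_ids_values,
    dense_shape=labels.dense_shape)
# Expand the sparse ids into a dense multi-hot indicator matrix.
multi_hot = math_ops.to_int64(
    sparse_ops.sparse_to_indicator(label_ids, 2))
with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    print(sess.run(multi_hot))  # [[1 0] [1 1]]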
@@ -254,7 +282,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
})
# Eval.
unweighted_loss, _ = self.create_loss(
unweighted_loss, processed_labels = self.create_loss(
features=features, mode=mode, logits=logits, labels=labels)
# Averages loss over classes.
per_example_loss = math_ops.reduce_mean(
@@ -268,7 +296,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
predictions=predictions,
loss=training_loss,
eval_metric_ops=self._eval_metric_ops(
labels=labels,
labels=processed_labels,
probabilities=probabilities,
weights=weights,
per_example_loss=per_example_loss))

tensorflow/contrib/estimator/python/estimator/head_test.py

@@ -113,6 +113,19 @@ class MultiLabelHead(test.TestCase):
r'thresholds must be in \(0, 1\) range\. Given: 1\.0'):
head_lib.multi_label_head(n_classes=2, thresholds=[0.5, 1.0])
def test_label_vocabulary_dict(self):
with self.assertRaisesRegexp(
ValueError,
r'label_vocabulary must be a list or tuple\. '
r'Given type: <(type|class) \'dict\'>'):
head_lib.multi_label_head(n_classes=2, label_vocabulary={'foo': 'bar'})
def test_label_vocabulary_wrong_size(self):
with self.assertRaisesRegexp(
ValueError,
r'Length of label_vocabulary must be n_classes \(3\). Given: 2'):
head_lib.multi_label_head(n_classes=3, label_vocabulary=['foo', 'bar'])
def test_predict(self):
n_classes = 4
head = head_lib.multi_label_head(n_classes)
@@ -219,35 +232,6 @@ class MultiLabelHead(test.TestCase):
self.assertAllClose(
expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
def test_eval_create_loss_sparse_labels(self):
"""Tests head.create_loss for eval mode and sparse labels."""
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
labels = sparse_tensor.SparseTensor(
values=[0, 0, 1],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
expected_labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
# For large logits, this is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits
expected_unweighted_loss = np.array(
[[10., 10.], [15., 0.]], dtype=np.float32)
actual_unweighted_loss, actual_labels = head.create_loss(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)
with self.test_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllEqual(expected_labels, actual_labels.eval())
self.assertAllClose(
expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
def test_eval_create_loss_labels_wrong_shape(self):
"""Tests head.create_loss for eval mode when labels has the wrong shape."""
n_classes = 2
@@ -273,35 +257,13 @@ class MultiLabelHead(test.TestCase):
actual_unweighted_loss.eval(
{labels_placeholder: np.array([1, 1], dtype=np.int64)})
def test_eval(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
# Average over classes, and sum over examples.
expected_loss = (
np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
)
def _test_eval(self, head, logits, labels, expected_loss, expected_metrics):
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
keys.LOSS_MEAN: expected_loss / 2,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
keys.AUC_PR: 0.7639,
}
# Assert spec contains expected tensors.
self.assertIsNotNone(spec.loss)
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
@@ -325,6 +287,100 @@ class MultiLabelHead(test.TestCase):
rtol=tol,
atol=tol)
def test_eval(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
# Average over classes, and sum over examples.
expected_loss = (
np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
)
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
keys.LOSS_MEAN: expected_loss / 2,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
keys.AUC_PR: 0.7639,
}
self._test_eval(
head=head,
logits=logits,
labels=labels,
expected_loss=expected_loss,
expected_metrics=expected_metrics)
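As a sanity check on the loss arithmetic in the comment above, a short NumPy sketch, independent of the test harness; `_sigmoid_cross_entropy` is reimplemented here from its stated formula:

import numpy as np

def _sigmoid_cross_entropy(labels, logits):
    # Elementwise labels * -log(sigmoid(logits)) +
    # (1 - labels) * -log(1 - sigmoid(logits)).
    sigmoid = 1. / (1. + np.exp(-logits))
    return -labels * np.log(sigmoid) - (1. - labels) * np.log(1. - sigmoid)

logits = np.array([[-1., 1.], [-1.5, 1.5]])
labels = np.array([[1., 0.], [1., 1.]])
loss = np.sum(_sigmoid_cross_entropy(labels, logits)) / 2  # n_classes = 2
print(loss)      # ~2.2647, the expected_loss above
print(loss / 2)  # ~1.1323, the LOSS_MEAN metric (averaged over 2 examples)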
def test_eval_sparse_labels(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
# Equivalent to multi_hot = [[1, 0], [1, 1]]
labels = sparse_tensor.SparseTensor(
values=[0, 0, 1],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
# Average over classes, and sum over examples.
expected_loss = (
np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) /
n_classes
)
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
keys.LOSS_MEAN: expected_loss / 2,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
keys.AUC_PR: 0.7639,
}
self._test_eval(
head=head,
logits=logits,
labels=labels,
expected_loss=expected_loss,
expected_metrics=expected_metrics)
def test_eval_with_label_vocabulary(self):
n_classes = 2
head = head_lib.multi_label_head(
n_classes, label_vocabulary=['class0', 'class1'])
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
# Equivalent to multi_hot = [[1, 0], [1, 1]]
labels = sparse_tensor.SparseTensor(
values=['class0', 'class0', 'class1'],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
# Average over classes, and sum over examples.
expected_loss = (
np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) /
n_classes
)
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
keys.LOSS_MEAN: expected_loss / 2,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
keys.AUC_PR: 0.7639,
}
self._test_eval(
head=head,
logits=logits,
labels=labels,
expected_loss=expected_loss,
expected_metrics=expected_metrics)
def test_eval_with_thresholds(self):
n_classes = 2
thresholds = [0.25, 0.5, 0.75]
@@ -339,12 +395,6 @@ class MultiLabelHead(test.TestCase):
np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
)
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
@@ -364,28 +414,12 @@ class MultiLabelHead(test.TestCase):
keys.RECALL_AT_THRESHOLD % thresholds[2]: 1. / 3.,
}
# Assert spec contains expected tensors.
self.assertIsNotNone(spec.loss)
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
# Assert predictions, loss, and metrics.
tol = 1e-3
with self.test_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
loss, metrics = sess.run((spec.loss, update_ops))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
# Check results of both update (in `metrics`) and value ops.
self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
self.assertAllClose(
expected_metrics, {k: value_ops[k].eval() for k in value_ops},
rtol=tol,
atol=tol)
self._test_eval(
head=head,
logits=logits,
labels=labels,
expected_loss=expected_loss,
expected_metrics=expected_metrics)
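The thresholded metrics checked above follow the usual definitions; a minimal NumPy sketch of one plausible reading (the probabilities below are hypothetical, not the test's values, and tie handling may differ from TF's metric ops):

import numpy as np

def precision_recall_at_threshold(probabilities, labels, threshold):
    # A prediction counts as positive when its probability exceeds the
    # threshold; precision and recall then follow from TP, FP, and FN.
    predicted = probabilities > threshold
    true_positives = np.sum(predicted & (labels == 1))
    precision = true_positives / max(np.sum(predicted), 1)
    recall = true_positives / max(np.sum(labels == 1), 1)
    return precision, recall

probabilities = np.array([[0.3, 0.7], [0.2, 0.8]])  # hypothetical
labels = np.array([[1, 0], [1, 1]])
print(precision_recall_at_threshold(probabilities, labels, 0.5))  # (0.5, 1/3)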
def test_eval_with_weights(self):
n_classes = 2
@@ -466,18 +500,7 @@ class MultiLabelHead(test.TestCase):
self.assertAllClose(
expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
def test_train(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes)
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# For large logits, sigmoid cross entropy loss is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
# Average over classes, sum over weights.
expected_loss = 17.5
def _test_train(self, head, logits, labels, expected_loss):
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
return string_ops.string_join(
@@ -514,6 +537,54 @@ class MultiLabelHead(test.TestCase):
metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
}, summary_str, tol)
def test_train(self):
head = head_lib.multi_label_head(n_classes=2)
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# For large logits, sigmoid cross entropy loss is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
# Average over classes, sum over weights.
expected_loss = 17.5
self._test_train(
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
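To see where 17.5 comes from, a quick NumPy check of the large-logit approximation described in the comment, standalone and outside the test:

import numpy as np

logits = np.array([[-10., 10.], [-15., 10.]])
labels = np.array([[1., 0.], [1., 1.]])
# For large |logits|, sigmoid cross entropy is ~|logit| when the sign
# disagrees with the label and ~0 when it agrees:
approx_loss = (labels * (logits < 0) * (-logits) +
               (1. - labels) * (logits > 0) * logits)
# approx_loss == [[10., 10.], [15., 0.]]
# Average over classes, then sum over examples:
print(np.sum(np.mean(approx_loss, axis=1)))  # 17.5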
def test_train_sparse_labels(self):
head = head_lib.multi_label_head(n_classes=2)
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
# Equivalent to multi_hot = [[1, 0], [1, 1]]
labels = sparse_tensor.SparseTensor(
values=[0, 0, 1],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
# For large logits, sigmoid cross entropy loss is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
# Average over classes, sum over weights.
expected_loss = 17.5
self._test_train(
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
def test_train_with_label_vocabulary(self):
head = head_lib.multi_label_head(
n_classes=2, label_vocabulary=['class0', 'class1'])
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
# Equivalent to multi_hot = [[1, 0], [1, 1]]
labels = sparse_tensor.SparseTensor(
values=['class0', 'class0', 'class1'],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
# For large logits, sigmoid cross entropy loss is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
# Average over classes, sum over weights.
expected_loss = 17.5
self._test_train(
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
def test_train_with_weights(self):
n_classes = 2
head = head_lib.multi_label_head(n_classes, weight_column='label_weights')