Supports label_vocabulary in multi_label_head.
PiperOrigin-RevId: 169260799
This commit is contained in:
parent
3588ff74d7
commit
e9bd04ad1f
tensorflow/contrib/estimator
@ -103,9 +103,15 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:lookup_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:metrics",
|
||||
"//tensorflow/python:sparse_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python/estimator:export_output",
|
||||
"//tensorflow/python/estimator:head",
|
||||
@ -124,9 +130,13 @@ py_test(
|
||||
deps = [
|
||||
":head",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/estimator:metric_keys",
|
||||
|
@ -23,10 +23,12 @@ from tensorflow.python.estimator.canned import head as head_lib
|
||||
from tensorflow.python.estimator.canned import metric_keys
|
||||
from tensorflow.python.estimator.canned import prediction_keys
|
||||
from tensorflow.python.estimator.export import export_output
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import lookup_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import metrics as metrics_lib
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
@ -138,10 +140,10 @@ def regression_head(weight_column=None,
|
||||
head_name=head_name)
|
||||
|
||||
|
||||
# TODO(roumposg): Support label_vocabulary.
|
||||
def multi_label_head(n_classes,
|
||||
weight_column=None,
|
||||
thresholds=None,
|
||||
label_vocabulary=None,
|
||||
head_name=None):
|
||||
"""Creates a `_Head` for multi-label classification.
|
||||
|
||||
@ -164,6 +166,11 @@ def multi_label_head(n_classes,
|
||||
and recall metrics are evaluated for each threshold value. The threshold
|
||||
is applied to the predicted probabilities, i.e. above the threshold is
|
||||
`true`, below is `false`.
|
||||
label_vocabulary: A list of strings represents possible label values. If it
|
||||
is not given, that means labels are already encoded as integer within
|
||||
[0, n_classes) or multi-hot Tensor. If given, labels must be SparseTensor
|
||||
string type and have any value in `label_vocabulary`. Also there will be
|
||||
errors if vocabulary is not provided and labels are string.
|
||||
head_name: name of the head. If provided, summary and metrics keys will be
|
||||
suffixed by `"/" + head_name`.
|
||||
|
||||
@ -182,9 +189,18 @@ def multi_label_head(n_classes,
|
||||
if (threshold <= 0.0) or (threshold >= 1.0):
|
||||
raise ValueError(
|
||||
'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
|
||||
if label_vocabulary is not None:
|
||||
if not isinstance(label_vocabulary, (list, tuple)):
|
||||
raise ValueError(
|
||||
'label_vocabulary must be a list or tuple. '
|
||||
'Given type: {}'.format(type(label_vocabulary)))
|
||||
if len(label_vocabulary) != n_classes:
|
||||
raise ValueError(
|
||||
'Length of label_vocabulary must be n_classes ({}). '
|
||||
'Given: {}'.format(n_classes, len(label_vocabulary)))
|
||||
return _MultiLabelHead(
|
||||
n_classes=n_classes, weight_column=weight_column, thresholds=thresholds,
|
||||
head_name=head_name)
|
||||
label_vocabulary=label_vocabulary, head_name=head_name)
|
||||
|
||||
|
||||
class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
|
||||
@ -194,10 +210,12 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
|
||||
n_classes,
|
||||
weight_column=None,
|
||||
thresholds=None,
|
||||
label_vocabulary=None,
|
||||
head_name=None):
|
||||
self._n_classes = n_classes
|
||||
self._weight_column = weight_column
|
||||
self._thresholds = thresholds
|
||||
self._label_vocabulary = label_vocabulary
|
||||
self._head_name = head_name
|
||||
|
||||
@property
|
||||
@ -206,8 +224,18 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
|
||||
|
||||
def _process_labels(self, labels):
|
||||
if isinstance(labels, sparse_tensor.SparseTensor):
|
||||
if labels.dtype == dtypes.string:
|
||||
label_ids_values = lookup_ops.index_table_from_tensor(
|
||||
vocabulary_list=tuple(self._label_vocabulary),
|
||||
name='class_id_lookup').lookup(labels.values)
|
||||
label_ids = sparse_tensor.SparseTensor(
|
||||
indices=labels.indices,
|
||||
values=label_ids_values,
|
||||
dense_shape=labels.dense_shape)
|
||||
else:
|
||||
label_ids = labels
|
||||
return math_ops.to_int64(
|
||||
sparse_ops.sparse_to_indicator(labels, self._n_classes))
|
||||
sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
|
||||
msg = ('labels shape must be [batch_size, {}]. '
|
||||
'Given: ').format(self._n_classes)
|
||||
labels_shape = array_ops.shape(labels)
|
||||
@ -254,7 +282,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
|
||||
})
|
||||
|
||||
# Eval.
|
||||
unweighted_loss, _ = self.create_loss(
|
||||
unweighted_loss, processed_labels = self.create_loss(
|
||||
features=features, mode=mode, logits=logits, labels=labels)
|
||||
# Averages loss over classes.
|
||||
per_example_loss = math_ops.reduce_mean(
|
||||
@ -268,7 +296,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
|
||||
predictions=predictions,
|
||||
loss=training_loss,
|
||||
eval_metric_ops=self._eval_metric_ops(
|
||||
labels=labels,
|
||||
labels=processed_labels,
|
||||
probabilities=probabilities,
|
||||
weights=weights,
|
||||
per_example_loss=per_example_loss))
|
||||
|
@ -113,6 +113,19 @@ class MultiLabelHead(test.TestCase):
|
||||
r'thresholds must be in \(0, 1\) range\. Given: 1\.0'):
|
||||
head_lib.multi_label_head(n_classes=2, thresholds=[0.5, 1.0])
|
||||
|
||||
def test_label_vocabulary_dict(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r'label_vocabulary must be a list or tuple\. '
|
||||
r'Given type: <(type|class) \'dict\'>'):
|
||||
head_lib.multi_label_head(n_classes=2, label_vocabulary={'foo': 'bar'})
|
||||
|
||||
def test_label_vocabulary_wrong_size(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
r'Length of label_vocabulary must be n_classes \(3\). Given: 2'):
|
||||
head_lib.multi_label_head(n_classes=3, label_vocabulary=['foo', 'bar'])
|
||||
|
||||
def test_predict(self):
|
||||
n_classes = 4
|
||||
head = head_lib.multi_label_head(n_classes)
|
||||
@ -219,35 +232,6 @@ class MultiLabelHead(test.TestCase):
|
||||
self.assertAllClose(
|
||||
expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
|
||||
|
||||
def test_eval_create_loss_sparse_labels(self):
|
||||
"""Tests head.create_loss for eval mode and sparse labels."""
|
||||
n_classes = 2
|
||||
head = head_lib.multi_label_head(n_classes)
|
||||
|
||||
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
|
||||
labels = sparse_tensor.SparseTensor(
|
||||
values=[0, 0, 1],
|
||||
indices=[[0, 0], [1, 0], [1, 1]],
|
||||
dense_shape=[2, 2])
|
||||
expected_labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
|
||||
# loss = labels * -log(sigmoid(logits)) +
|
||||
# (1 - labels) * -log(1 - sigmoid(logits))
|
||||
# For large logits, this is approximated as:
|
||||
# loss = labels * (logits < 0) * (-logits) +
|
||||
# (1 - labels) * (logits > 0) * logits
|
||||
expected_unweighted_loss = np.array(
|
||||
[[10., 10.], [15., 0.]], dtype=np.float32)
|
||||
actual_unweighted_loss, actual_labels = head.create_loss(
|
||||
features={'x': np.array(((42,),), dtype=np.int32)},
|
||||
mode=model_fn.ModeKeys.EVAL,
|
||||
logits=logits,
|
||||
labels=labels)
|
||||
with self.test_session():
|
||||
_initialize_variables(self, monitored_session.Scaffold())
|
||||
self.assertAllEqual(expected_labels, actual_labels.eval())
|
||||
self.assertAllClose(
|
||||
expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
|
||||
|
||||
def test_eval_create_loss_labels_wrong_shape(self):
|
||||
"""Tests head.create_loss for eval mode when labels has the wrong shape."""
|
||||
n_classes = 2
|
||||
@ -273,35 +257,13 @@ class MultiLabelHead(test.TestCase):
|
||||
actual_unweighted_loss.eval(
|
||||
{labels_placeholder: np.array([1, 1], dtype=np.int64)})
|
||||
|
||||
def test_eval(self):
|
||||
n_classes = 2
|
||||
head = head_lib.multi_label_head(n_classes)
|
||||
|
||||
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
|
||||
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
|
||||
# loss = labels * -log(sigmoid(logits)) +
|
||||
# (1 - labels) * -log(1 - sigmoid(logits))
|
||||
# Average over classes, and sum over examples.
|
||||
expected_loss = (
|
||||
np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
|
||||
)
|
||||
|
||||
def _test_eval(self, head, logits, labels, expected_loss, expected_metrics):
|
||||
spec = head.create_estimator_spec(
|
||||
features={'x': np.array(((42,),), dtype=np.int32)},
|
||||
mode=model_fn.ModeKeys.EVAL,
|
||||
logits=logits,
|
||||
labels=labels)
|
||||
|
||||
keys = metric_keys.MetricKeys
|
||||
expected_metrics = {
|
||||
# Average loss over examples.
|
||||
keys.LOSS_MEAN: expected_loss / 2,
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.3333,
|
||||
keys.AUC_PR: 0.7639,
|
||||
}
|
||||
|
||||
# Assert spec contains expected tensors.
|
||||
self.assertIsNotNone(spec.loss)
|
||||
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
|
||||
@ -325,6 +287,100 @@ class MultiLabelHead(test.TestCase):
|
||||
rtol=tol,
|
||||
atol=tol)
|
||||
|
||||
def test_eval(self):
|
||||
n_classes = 2
|
||||
head = head_lib.multi_label_head(n_classes)
|
||||
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
|
||||
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
|
||||
# loss = labels * -log(sigmoid(logits)) +
|
||||
# (1 - labels) * -log(1 - sigmoid(logits))
|
||||
# Average over classes, and sum over examples.
|
||||
expected_loss = (
|
||||
np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
|
||||
)
|
||||
keys = metric_keys.MetricKeys
|
||||
expected_metrics = {
|
||||
# Average loss over examples.
|
||||
keys.LOSS_MEAN: expected_loss / 2,
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.3333,
|
||||
keys.AUC_PR: 0.7639,
|
||||
}
|
||||
self._test_eval(
|
||||
head=head,
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
expected_loss=expected_loss,
|
||||
expected_metrics=expected_metrics)
|
||||
|
||||
def test_eval_sparse_labels(self):
|
||||
n_classes = 2
|
||||
head = head_lib.multi_label_head(n_classes)
|
||||
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
|
||||
# Equivalent to multi_hot = [[1, 0], [1, 1]]
|
||||
labels = sparse_tensor.SparseTensor(
|
||||
values=[0, 0, 1],
|
||||
indices=[[0, 0], [1, 0], [1, 1]],
|
||||
dense_shape=[2, 2])
|
||||
labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
|
||||
# loss = labels * -log(sigmoid(logits)) +
|
||||
# (1 - labels) * -log(1 - sigmoid(logits))
|
||||
# Average over classes, and sum over examples.
|
||||
expected_loss = (
|
||||
np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) /
|
||||
n_classes
|
||||
)
|
||||
keys = metric_keys.MetricKeys
|
||||
expected_metrics = {
|
||||
# Average loss over examples.
|
||||
keys.LOSS_MEAN: expected_loss / 2,
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.3333,
|
||||
keys.AUC_PR: 0.7639,
|
||||
}
|
||||
self._test_eval(
|
||||
head=head,
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
expected_loss=expected_loss,
|
||||
expected_metrics=expected_metrics)
|
||||
|
||||
def test_eval_with_label_vocabulary(self):
|
||||
n_classes = 2
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes, label_vocabulary=['class0', 'class1'])
|
||||
logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
|
||||
# Equivalent to multi_hot = [[1, 0], [1, 1]]
|
||||
labels = sparse_tensor.SparseTensor(
|
||||
values=['class0', 'class0', 'class1'],
|
||||
indices=[[0, 0], [1, 0], [1, 1]],
|
||||
dense_shape=[2, 2])
|
||||
labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
|
||||
# loss = labels * -log(sigmoid(logits)) +
|
||||
# (1 - labels) * -log(1 - sigmoid(logits))
|
||||
# Average over classes, and sum over examples.
|
||||
expected_loss = (
|
||||
np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) /
|
||||
n_classes
|
||||
)
|
||||
keys = metric_keys.MetricKeys
|
||||
expected_metrics = {
|
||||
# Average loss over examples.
|
||||
keys.LOSS_MEAN: expected_loss / 2,
|
||||
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
|
||||
# this assert tests that the algorithm remains consistent.
|
||||
keys.AUC: 0.3333,
|
||||
keys.AUC_PR: 0.7639,
|
||||
}
|
||||
self._test_eval(
|
||||
head=head,
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
expected_loss=expected_loss,
|
||||
expected_metrics=expected_metrics)
|
||||
|
||||
def test_eval_with_thresholds(self):
|
||||
n_classes = 2
|
||||
thresholds = [0.25, 0.5, 0.75]
|
||||
@ -339,12 +395,6 @@ class MultiLabelHead(test.TestCase):
|
||||
np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes
|
||||
)
|
||||
|
||||
spec = head.create_estimator_spec(
|
||||
features={'x': np.array(((42,),), dtype=np.int32)},
|
||||
mode=model_fn.ModeKeys.EVAL,
|
||||
logits=logits,
|
||||
labels=labels)
|
||||
|
||||
keys = metric_keys.MetricKeys
|
||||
expected_metrics = {
|
||||
# Average loss over examples.
|
||||
@ -364,28 +414,12 @@ class MultiLabelHead(test.TestCase):
|
||||
keys.RECALL_AT_THRESHOLD % thresholds[2]: 1. / 3.,
|
||||
}
|
||||
|
||||
# Assert spec contains expected tensors.
|
||||
self.assertIsNotNone(spec.loss)
|
||||
self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
|
||||
self.assertIsNone(spec.train_op)
|
||||
self.assertIsNone(spec.export_outputs)
|
||||
_assert_no_hooks(self, spec)
|
||||
|
||||
# Assert predictions, loss, and metrics.
|
||||
tol = 1e-3
|
||||
with self.test_session() as sess:
|
||||
_initialize_variables(self, spec.scaffold)
|
||||
self.assertIsNone(spec.scaffold.summary_op)
|
||||
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
|
||||
update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
|
||||
loss, metrics = sess.run((spec.loss, update_ops))
|
||||
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
|
||||
# Check results of both update (in `metrics`) and value ops.
|
||||
self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
|
||||
self.assertAllClose(
|
||||
expected_metrics, {k: value_ops[k].eval() for k in value_ops},
|
||||
rtol=tol,
|
||||
atol=tol)
|
||||
self._test_eval(
|
||||
head=head,
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
expected_loss=expected_loss,
|
||||
expected_metrics=expected_metrics)
|
||||
|
||||
def test_eval_with_weights(self):
|
||||
n_classes = 2
|
||||
@ -466,18 +500,7 @@ class MultiLabelHead(test.TestCase):
|
||||
self.assertAllClose(
|
||||
expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4)
|
||||
|
||||
def test_train(self):
|
||||
n_classes = 2
|
||||
head = head_lib.multi_label_head(n_classes)
|
||||
|
||||
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
|
||||
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
|
||||
# For large logits, sigmoid cross entropy loss is approximated as:
|
||||
# loss = labels * (logits < 0) * (-logits) +
|
||||
# (1 - labels) * (logits > 0) * logits =>
|
||||
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
|
||||
# Average over classes, sum over weights.
|
||||
expected_loss = 17.5
|
||||
def _test_train(self, head, logits, labels, expected_loss):
|
||||
expected_train_result = 'my_train_op'
|
||||
def _train_op_fn(loss):
|
||||
return string_ops.string_join(
|
||||
@ -514,6 +537,54 @@ class MultiLabelHead(test.TestCase):
|
||||
metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
|
||||
}, summary_str, tol)
|
||||
|
||||
def test_train(self):
|
||||
head = head_lib.multi_label_head(n_classes=2)
|
||||
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
|
||||
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
|
||||
# For large logits, sigmoid cross entropy loss is approximated as:
|
||||
# loss = labels * (logits < 0) * (-logits) +
|
||||
# (1 - labels) * (logits > 0) * logits =>
|
||||
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
|
||||
# Average over classes, sum over weights.
|
||||
expected_loss = 17.5
|
||||
self._test_train(
|
||||
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
|
||||
|
||||
def test_train_sparse_labels(self):
|
||||
head = head_lib.multi_label_head(n_classes=2)
|
||||
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
|
||||
# Equivalent to multi_hot = [[1, 0], [1, 1]]
|
||||
labels = sparse_tensor.SparseTensor(
|
||||
values=[0, 0, 1],
|
||||
indices=[[0, 0], [1, 0], [1, 1]],
|
||||
dense_shape=[2, 2])
|
||||
# For large logits, sigmoid cross entropy loss is approximated as:
|
||||
# loss = labels * (logits < 0) * (-logits) +
|
||||
# (1 - labels) * (logits > 0) * logits =>
|
||||
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
|
||||
# Average over classes, sum over weights.
|
||||
expected_loss = 17.5
|
||||
self._test_train(
|
||||
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
|
||||
|
||||
def test_train_with_label_vocabulary(self):
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=2, label_vocabulary=['class0', 'class1'])
|
||||
logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
|
||||
# Equivalent to multi_hot = [[1, 0], [1, 1]]
|
||||
labels = sparse_tensor.SparseTensor(
|
||||
values=['class0', 'class0', 'class1'],
|
||||
indices=[[0, 0], [1, 0], [1, 1]],
|
||||
dense_shape=[2, 2])
|
||||
# For large logits, sigmoid cross entropy loss is approximated as:
|
||||
# loss = labels * (logits < 0) * (-logits) +
|
||||
# (1 - labels) * (logits > 0) * logits =>
|
||||
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
|
||||
# Average over classes, sum over weights.
|
||||
expected_loss = 17.5
|
||||
self._test_train(
|
||||
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
|
||||
|
||||
def test_train_with_weights(self):
|
||||
n_classes = 2
|
||||
head = head_lib.multi_label_head(n_classes, weight_column='label_weights')
|
||||
|
Loading…
Reference in New Issue
Block a user