Adds label_vocab to all canned Classifiers.
PiperOrigin-RevId: 158804380
This commit is contained in:
parent
5e06dbdc00
commit
b044a5d5e8
tensorflow/python/estimator/canned
@ -196,6 +196,7 @@ class DNNClassifier(estimator.Estimator):
|
||||
model_dir=None,
|
||||
n_classes=2,
|
||||
weight_feature_key=None,
|
||||
label_vocabulary=None,
|
||||
optimizer='Adagrad',
|
||||
activation_fn=nn.relu,
|
||||
dropout=None,
|
||||
@ -218,6 +219,13 @@ class DNNClassifier(estimator.Estimator):
|
||||
weight_feature_key: A string defining feature column name representing
|
||||
weights. It is used to down weight or boost examples during training. It
|
||||
will be multiplied by the loss of the example.
|
||||
label_vocabulary: A list of strings represents possible label values. If
|
||||
given, labels must be string type and have any value in
|
||||
`label_vocabulary`. If it is not given, that means labels are
|
||||
already encoded as integer or float within [0, 1] for `n_classes=2` and
|
||||
encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
|
||||
Also there will be errors if vocabulary is not provided and labels are
|
||||
string.
|
||||
optimizer: An instance of `tf.Optimizer` used to train the model. If
|
||||
`None`, will use an Adagrad optimizer.
|
||||
activation_fn: Activation function applied to each layer. If `None`, will
|
||||
@ -230,10 +238,12 @@ class DNNClassifier(estimator.Estimator):
|
||||
"""
|
||||
if n_classes == 2:
|
||||
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
|
||||
weight_column=weight_feature_key)
|
||||
weight_column=weight_feature_key,
|
||||
label_vocabulary=label_vocabulary)
|
||||
else:
|
||||
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
|
||||
n_classes, weight_column=weight_feature_key)
|
||||
n_classes, weight_column=weight_feature_key,
|
||||
label_vocabulary=label_vocabulary)
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return _dnn_model_fn(
|
||||
features=features,
|
||||
|
@ -308,6 +308,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
|
||||
dnn_dropout=None,
|
||||
n_classes=2,
|
||||
weight_feature_key=None,
|
||||
label_vocabulary=None,
|
||||
input_layer_partitioner=None,
|
||||
config=None):
|
||||
"""Initializes a DNNLinearCombinedClassifier instance.
|
||||
@ -337,6 +338,13 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
|
||||
weight_feature_key: A string defining feature column name representing
|
||||
weights. It is used to down weight or boost examples during training. It
|
||||
will be multiplied by the loss of the example.
|
||||
label_vocabulary: A list of strings represents possible label values. If
|
||||
given, labels must be string type and have any value in
|
||||
`label_vocabulary`. If it is not given, that means labels are
|
||||
already encoded as integer or float within [0, 1] for `n_classes=2` and
|
||||
encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
|
||||
Also there will be errors if vocabulary is not provided and labels are
|
||||
string.
|
||||
input_layer_partitioner: Partitioner for input layer. Defaults to
|
||||
`min_max_variable_partitioner` with `min_slice_size` 64 << 20.
|
||||
config: RunConfig object to configure the runtime settings.
|
||||
@ -354,11 +362,13 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
|
||||
'must be defined.')
|
||||
if n_classes == 2:
|
||||
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
|
||||
weight_column=weight_feature_key)
|
||||
weight_column=weight_feature_key,
|
||||
label_vocabulary=label_vocabulary)
|
||||
else:
|
||||
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
|
||||
n_classes,
|
||||
weight_column=weight_feature_key)
|
||||
weight_column=weight_feature_key,
|
||||
label_vocabulary=label_vocabulary)
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return _dnn_linear_combined_model_fn(
|
||||
features=features,
|
||||
|
@ -325,6 +325,7 @@ def _dnn_classifier_fn(
|
||||
model_dir=None,
|
||||
n_classes=2,
|
||||
weight_feature_key=None,
|
||||
label_vocabulary=None,
|
||||
optimizer='Adagrad',
|
||||
config=None,
|
||||
input_layer_partitioner=None):
|
||||
@ -335,6 +336,7 @@ def _dnn_classifier_fn(
|
||||
dnn_optimizer=optimizer,
|
||||
n_classes=n_classes,
|
||||
weight_feature_key=weight_feature_key,
|
||||
label_vocabulary=label_vocabulary,
|
||||
input_layer_partitioner=input_layer_partitioner,
|
||||
config=config)
|
||||
|
||||
|
@ -624,7 +624,7 @@ class BaseDNNClassifierPredictTest(object):
|
||||
writer_cache.FileWriterCache.clear()
|
||||
shutil.rmtree(self._model_dir)
|
||||
|
||||
def test_one_dim(self):
|
||||
def _test_one_dim(self, label_vocabulary, label_output_fn):
|
||||
"""Asserts predictions for one-dimensional input and logits."""
|
||||
create_checkpoint(
|
||||
(([[.6, .5]], [.1, -.1]), ([[1., .8], [-.8, -1.]], [.2, -.2]),
|
||||
@ -634,6 +634,7 @@ class BaseDNNClassifierPredictTest(object):
|
||||
|
||||
dnn_classifier = self._dnn_classifier_fn(
|
||||
hidden_units=(2, 2),
|
||||
label_vocabulary=label_vocabulary,
|
||||
feature_columns=(feature_column.numeric_column('x'),),
|
||||
model_dir=self._model_dir)
|
||||
input_fn = numpy_io.numpy_input_fn(
|
||||
@ -654,10 +655,20 @@ class BaseDNNClassifierPredictTest(object):
|
||||
0.11105597], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
|
||||
self.assertAllClose([0],
|
||||
predictions[prediction_keys.PredictionKeys.CLASS_IDS])
|
||||
self.assertAllEqual([b'0'],
|
||||
self.assertAllEqual([label_output_fn(0)],
|
||||
predictions[prediction_keys.PredictionKeys.CLASSES])
|
||||
|
||||
def test_multi_dim(self):
|
||||
def test_one_dim_without_label_vocabulary(self):
|
||||
self._test_one_dim(label_vocabulary=None,
|
||||
label_output_fn=lambda x: ('%s' % x).encode())
|
||||
|
||||
def test_one_dim_with_label_vocabulary(self):
|
||||
n_classes = 2
|
||||
self._test_one_dim(
|
||||
label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],
|
||||
label_output_fn=lambda x: ('class_vocab_%s' % x).encode())
|
||||
|
||||
def _test_multi_dim_with_3_classes(self, label_vocabulary, label_output_fn):
|
||||
"""Asserts predictions for multi-dimensional input and logits."""
|
||||
create_checkpoint(
|
||||
(([[.6, .5], [-.6, -.5]], [.1, -.1]),
|
||||
@ -669,6 +680,7 @@ class BaseDNNClassifierPredictTest(object):
|
||||
dnn_classifier = self._dnn_classifier_fn(
|
||||
hidden_units=(2, 2),
|
||||
feature_columns=(feature_column.numeric_column('x', shape=(2,)),),
|
||||
label_vocabulary=label_vocabulary,
|
||||
n_classes=3,
|
||||
model_dir=self._model_dir)
|
||||
input_fn = numpy_io.numpy_input_fn(
|
||||
@ -698,7 +710,19 @@ class BaseDNNClassifierPredictTest(object):
|
||||
self.assertAllEqual(
|
||||
[1], predictions[prediction_keys.PredictionKeys.CLASS_IDS])
|
||||
self.assertAllEqual(
|
||||
[b'1'], predictions[prediction_keys.PredictionKeys.CLASSES])
|
||||
[label_output_fn(1)],
|
||||
predictions[prediction_keys.PredictionKeys.CLASSES])
|
||||
|
||||
def test_multi_dim_with_3_classes_but_no_label_vocab(self):
|
||||
self._test_multi_dim_with_3_classes(
|
||||
label_vocabulary=None,
|
||||
label_output_fn=lambda x: ('%s' % x).encode())
|
||||
|
||||
def test_multi_dim_with_3_classes_and_label_vocab(self):
|
||||
n_classes = 3
|
||||
self._test_multi_dim_with_3_classes(
|
||||
label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],
|
||||
label_output_fn=lambda x: ('class_vocab_%s' % x).encode())
|
||||
|
||||
|
||||
class BaseDNNRegressorPredictTest(object):
|
||||
|
@ -179,10 +179,12 @@ class LinearClassifier(estimator.Estimator):
|
||||
weights. It is used to down weight or boost examples during training. It
|
||||
will be multiplied by the loss of the example.
|
||||
label_vocabulary: A list of strings represents possible label values. If
|
||||
it is not given, that means labels are already encoded within [0, 1]. If
|
||||
given, labels must be string type and have any value in
|
||||
`label_vocabulary`. Also there will be errors if vocabulary is not
|
||||
provided and labels are string.
|
||||
`label_vocabulary`. If it is not given, that means labels are
|
||||
already encoded as integer or float within [0, 1] for `n_classes=2` and
|
||||
encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
|
||||
Also there will be errors if vocabulary is not provided and labels are
|
||||
string.
|
||||
optimizer: The optimizer used to train the model. If specified, it should
|
||||
be either an instance of `tf.Optimizer` or the SDCAOptimizer. If `None`,
|
||||
the Ftrl optimizer will be used.
|
||||
|
Loading…
Reference in New Issue
Block a user