Add weight-column support to the heads.

PiperOrigin-RevId: 158409180
This commit is contained in:
Mustafa Ispir 2017-06-08 10:15:26 -07:00 committed by TensorFlower Gardener
parent 7fb52cd54c
commit d35cbbb447
6 changed files with 112 additions and 79 deletions

View File

@ -394,6 +394,7 @@ py_library(
"//tensorflow/python:string_ops", "//tensorflow/python:string_ops",
"//tensorflow/python:variable_scope", "//tensorflow/python:variable_scope",
"//tensorflow/python:weights_broadcast_ops", "//tensorflow/python:weights_broadcast_ops",
"//tensorflow/python/feature_column",
"//tensorflow/python/ops/losses", "//tensorflow/python/ops/losses",
], ],
) )

View File

@ -230,10 +230,10 @@ class DNNClassifier(estimator.Estimator):
""" """
if n_classes == 2: if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_feature_key=weight_feature_key) weight_column=weight_feature_key)
else: else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
n_classes, weight_feature_key=weight_feature_key) n_classes, weight_column=weight_feature_key)
def _model_fn(features, labels, mode, config): def _model_fn(features, labels, mode, config):
return _dnn_model_fn( return _dnn_model_fn(
features=features, features=features,
@ -351,9 +351,10 @@ class DNNRegressor(estimator.Estimator):
features=features, features=features,
labels=labels, labels=labels,
mode=mode, mode=mode,
head=head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access head=head_lib. # pylint: disable=protected-access
_regression_head_with_mean_squared_error_loss(
label_dimension=label_dimension, label_dimension=label_dimension,
weight_feature_key=weight_feature_key), weight_column=weight_feature_key),
hidden_units=hidden_units, hidden_units=hidden_units,
feature_columns=tuple(feature_columns or []), feature_columns=tuple(feature_columns or []),
optimizer=optimizer, optimizer=optimizer,

View File

@ -307,6 +307,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
dnn_activation_fn=nn.relu, dnn_activation_fn=nn.relu,
dnn_dropout=None, dnn_dropout=None,
n_classes=2, n_classes=2,
weight_feature_key=None,
input_layer_partitioner=None, input_layer_partitioner=None,
config=None): config=None):
"""Initializes a DNNLinearCombinedClassifier instance. """Initializes a DNNLinearCombinedClassifier instance.
@ -333,6 +334,9 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
a given coordinate. a given coordinate.
n_classes: Number of label classes. Defaults to 2, namely binary n_classes: Number of label classes. Defaults to 2, namely binary
classification. Must be > 1. classification. Must be > 1.
weight_feature_key: A string defining feature column name representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
input_layer_partitioner: Partitioner for input layer. Defaults to input_layer_partitioner: Partitioner for input layer. Defaults to
`min_max_variable_partitioner` with `min_slice_size` 64 << 20. `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: RunConfig object to configure the runtime settings. config: RunConfig object to configure the runtime settings.
@ -348,11 +352,12 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
raise ValueError('Either linear_feature_columns or dnn_feature_columns ' raise ValueError('Either linear_feature_columns or dnn_feature_columns '
'must be defined.') 'must be defined.')
if n_classes == 2: if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() # pylint: disable=protected-access head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_feature_key)
else: else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
n_classes) n_classes,
weight_column=weight_feature_key)
def _model_fn(features, labels, mode, config): def _model_fn(features, labels, mode, config):
return _dnn_linear_combined_model_fn( return _dnn_linear_combined_model_fn(
features=features, features=features,
@ -500,7 +505,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
head=head_lib. # pylint: disable=protected-access head=head_lib. # pylint: disable=protected-access
_regression_head_with_mean_squared_error_loss( _regression_head_with_mean_squared_error_loss(
label_dimension=label_dimension, label_dimension=label_dimension,
weight_feature_key=weight_feature_key), weight_column=weight_feature_key),
linear_feature_columns=linear_feature_columns, linear_feature_columns=linear_feature_columns,
linear_optimizer=linear_optimizer, linear_optimizer=linear_optimizer,
dnn_feature_columns=dnn_feature_columns, dnn_feature_columns=dnn_feature_columns,

View File

@ -26,6 +26,7 @@ from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output from tensorflow.python.estimator.export import export_output
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import sparse_tensor
@ -278,7 +279,7 @@ def _recall_at_threshold(labels, predictions, weights, threshold, name=None):
def _multi_class_head_with_softmax_cross_entropy_loss(n_classes, def _multi_class_head_with_softmax_cross_entropy_loss(n_classes,
weight_feature_key=None, weight_column=None,
label_vocabulary=None): label_vocabulary=None):
"""Creates a '_Head' for multi class classification. """Creates a '_Head' for multi class classification.
@ -287,7 +288,8 @@ def _multi_class_head_with_softmax_cross_entropy_loss(n_classes,
Args: Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use n_classes: Number of classes, must be greater than 2 (for 2 classes, use
`_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`). `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
weight_feature_key: A string defining feature column name representing weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example. will be multiplied by the loss of the example.
label_vocabulary: A list of strings represents possible label values. If it label_vocabulary: A list of strings represents possible label values. If it
@ -307,18 +309,18 @@ def _multi_class_head_with_softmax_cross_entropy_loss(n_classes,
raise ValueError('label_vocabulary should be a list. Given type: {}'.format( raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
type(label_vocabulary))) type(label_vocabulary)))
return _MultiClassHeadWithSoftmaxCrossEntropyLoss( return _MultiClassHeadWithSoftmaxCrossEntropyLoss(n_classes, weight_column,
n_classes, weight_feature_key, label_vocabulary) label_vocabulary)
class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head): class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
"""See `_multi_class_head_with_softmax_cross_entropy_loss`.""" """See `_multi_class_head_with_softmax_cross_entropy_loss`."""
def __init__(self, n_classes, weight_feature_key=None, label_vocabulary=None): def __init__(self, n_classes, weight_column=None, label_vocabulary=None):
if (n_classes is None) or (n_classes <= 2): if (n_classes is None) or (n_classes <= 2):
raise ValueError('n_classes must be > 2: %s.' % n_classes) raise ValueError('n_classes must be > 2: %s.' % n_classes)
self._n_classes = n_classes self._n_classes = n_classes
self._weight_feature_key = weight_feature_key self._weight_column = weight_column
self._label_vocabulary = label_vocabulary self._label_vocabulary = label_vocabulary
@property @property
@ -417,10 +419,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
labels=label_ids, logits=logits, reduction=losses.Reduction.NONE) labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
# Restore the squeezed dim, so unweighted_loss matches the weights shape. # Restore the squeezed dim, so unweighted_loss matches the weights shape.
unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=(1,)) unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=(1,))
weights = ( weights = _weights(features, self._weight_column)
1. if (self._weight_feature_key is None) else
features[self._weight_feature_key])
weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
training_loss = losses.compute_weighted_loss( training_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
if mode == model_fn.ModeKeys.EVAL: if mode == model_fn.ModeKeys.EVAL:
@ -453,7 +452,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
def _binary_logistic_head_with_sigmoid_cross_entropy_loss( def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
weight_feature_key=None, thresholds=None, label_vocabulary=None): weight_column=None, thresholds=None, label_vocabulary=None):
"""Creates a `Head` for single label binary classification. """Creates a `Head` for single label binary classification.
This head uses `sigmoid_cross_entropy_with_logits` loss. This head uses `sigmoid_cross_entropy_with_logits` loss.
@ -461,7 +460,8 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
This head expects to be fed float labels of shape `(batch_size, 1)`. This head expects to be fed float labels of shape `(batch_size, 1)`.
Args: Args:
weight_feature_key: A string defining feature column name representing weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example. will be multiplied by the loss of the example.
thresholds: Iterable of floats in the range `(0, 1)`. For binary thresholds: Iterable of floats in the range `(0, 1)`. For binary
@ -491,7 +491,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
if (threshold <= 0.0) or (threshold >= 1.0): if (threshold <= 0.0) or (threshold >= 1.0):
raise ValueError('thresholds not in (0, 1): %s.' % (thresholds,)) raise ValueError('thresholds not in (0, 1): %s.' % (thresholds,))
return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss( return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
weight_feature_key=weight_feature_key, weight_column=weight_column,
thresholds=thresholds, thresholds=thresholds,
label_vocabulary=label_vocabulary) label_vocabulary=label_vocabulary)
@ -499,11 +499,9 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
"""See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`.""" """See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`."""
def __init__(self, def __init__(self, weight_column=None, thresholds=None,
weight_feature_key=None,
thresholds=None,
label_vocabulary=None): label_vocabulary=None):
self._weight_feature_key = weight_feature_key self._weight_column = weight_column
self._thresholds = thresholds self._thresholds = thresholds
self._label_vocabulary = label_vocabulary self._label_vocabulary = label_vocabulary
@ -624,10 +622,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
labels = _assert_range(labels, 2) labels = _assert_range(labels, 2)
unweighted_loss = nn.sigmoid_cross_entropy_with_logits( unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
labels=labels, logits=logits, name='loss') labels=labels, logits=logits, name='loss')
weights = ( weights = _weights(features, self._weight_column)
1. if (self._weight_feature_key is None) else
features[self._weight_feature_key])
weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
training_loss = losses.compute_weighted_loss( training_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
if mode == model_fn.ModeKeys.EVAL: if mode == model_fn.ModeKeys.EVAL:
@ -660,13 +655,13 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
train_op=train_op_fn(training_loss)) train_op=train_op_fn(training_loss))
def _regression_head_with_mean_squared_error_loss( def _regression_head_with_mean_squared_error_loss(weight_column=None,
weight_feature_key=None, label_dimension=1):
label_dimension=1):
"""Creates a `_Head` for regression using the mean squared loss. """Creates a `_Head` for regression using the mean squared loss.
Args: Args:
weight_feature_key: A string defining feature column name representing weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example. will be multiplied by the loss of the example.
label_dimension: Number of regression labels per example. This is the size label_dimension: Number of regression labels per example. This is the size
@ -677,33 +672,18 @@ def _regression_head_with_mean_squared_error_loss(
An instance of `_Head` for linear regression. An instance of `_Head` for linear regression.
""" """
return _RegressionHeadWithMeanSquaredErrorLoss( return _RegressionHeadWithMeanSquaredErrorLoss(
weight_feature_key=weight_feature_key, weight_column=weight_column, label_dimension=label_dimension)
label_dimension=label_dimension)
class _RegressionHeadWithMeanSquaredErrorLoss(_Head): class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
"""`Head` for regression using the mean squared loss.""" """`Head` for regression using the mean squared loss."""
def __init__(self, def __init__(self, label_dimension, weight_column=None):
label_dimension, """`Head` for regression."""
weight_feature_key=None):
"""`Head` for regression.
Args:
label_dimension: Number of regression labels per example. This is the
size of the last dimension of the labels `Tensor` (typically, this has
shape `[batch_size, label_dimension]`).
weight_feature_key: A string defining feature column name representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
Raises:
ValueError: if `label_dimension` < 1.
"""
if label_dimension < 1: if label_dimension < 1:
raise ValueError('Invalid label_dimension %s.' % label_dimension) raise ValueError('Invalid label_dimension %s.' % label_dimension)
self._logits_dimension = label_dimension self._logits_dimension = label_dimension
self._weight_feature_key = weight_feature_key self._weight_column = weight_column
@property @property
def logits_dimension(self): def logits_dimension(self):
@ -731,10 +711,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
self._logits_dimension) self._logits_dimension)
unweighted_loss = losses.mean_squared_error( unweighted_loss = losses.mean_squared_error(
labels=labels, predictions=logits, reduction=losses.Reduction.NONE) labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
weights = ( weights = _weights(features, self._weight_column)
1. if (self._weight_feature_key is None) else
features[self._weight_feature_key])
weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
training_loss = losses.compute_weighted_loss( training_loss = losses.compute_weighted_loss(
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
if mode == model_fn.ModeKeys.EVAL: if mode == model_fn.ModeKeys.EVAL:
@ -774,3 +751,21 @@ def _assert_range(labels, n_classes):
labels, message='Label IDs must >= 0') labels, message='Label IDs must >= 0')
with ops.control_dependencies((assert_less, assert_greater)): with ops.control_dependencies((assert_less, assert_greater)):
return array_ops.identity(labels) return array_ops.identity(labels)
def _weights(features, weight_column):
  """Looks up the example weights described by `weight_column`.

  Args:
    features: Features object handed to `feature_column_lib._LazyBuilder`.
    weight_column: `None`, a string feature key, or a `_NumericColumn`. A
      string is wrapped with `feature_column_lib.numeric_column(key=...)`
      before the lookup.

  Returns:
    The Python float `1.` when `weight_column` is `None`. Otherwise the dense
    tensor produced by the column, converted with `math_ops.to_float` and then
    passed through `_maybe_expand_dim`.

  Raises:
    TypeError: If `weight_column` is neither a string nor a `_NumericColumn`.
    ValueError: If the fetched tensor's dtype is neither floating nor integer.
  """
  # No weight column configured: every example gets the same unit weight.
  if weight_column is None:
    return 1.

  column = weight_column
  if isinstance(column, six.string_types):
    # A bare key is shorthand for a numeric column with that key.
    column = feature_column_lib.numeric_column(key=column)
  if not isinstance(column, feature_column_lib._NumericColumn):  # pylint: disable=protected-access
    raise TypeError(
        'Weight column must be either a string or _NumericColumn. '
        'Given type: {}.'.format(type(column)))

  builder = feature_column_lib._LazyBuilder(features)  # pylint: disable=protected-access
  raw_weights = column._get_dense_tensor(builder)  # pylint: disable=protected-access

  # Reject dtypes other than floating/integer before the float conversion.
  weights_dtype = raw_weights.dtype
  if not weights_dtype.is_floating and not weights_dtype.is_integer:
    raise ValueError(
        'Weight column should be castable to float. '
        'Given dtype: {}'.format(weights_dtype))

  # NOTE(review): _maybe_expand_dim is defined elsewhere in this file;
  # presumably it adds a trailing dimension to rank-1 weights -- confirm.
  return _maybe_expand_dim(math_ops.to_float(raw_weights, name='weights'))

View File

@ -27,6 +27,7 @@ from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
@ -338,7 +339,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
expected_probabilities = [[0.576117, 0.2119416, 0.2119416], expected_probabilities = [[0.576117, 0.2119416, 0.2119416],
[0.2119416, 0.2119416, 0.576117]] [0.2119416, 0.2119416, 0.576117]]
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes, weight_feature_key='label_weights') n_classes, weight_column='label_weights')
weights_2x1 = [[1.], [2.]] weights_2x1 = [[1.], [2.]]
spec = head.create_estimator_spec( spec = head.create_estimator_spec(
@ -440,7 +441,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
def test_weighted_multi_example_eval(self): def test_weighted_multi_example_eval(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes, weight_feature_key='label_weights') n_classes, weight_column='label_weights')
# Create estimator spec. # Create estimator spec.
logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32) logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)
@ -534,7 +535,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
def test_train_with_one_dim_label_and_weights(self): def test_train_with_one_dim_label_and_weights(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes, weight_feature_key='label_weights') n_classes, weight_column='label_weights')
logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32) logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)
labels_rank_1 = np.array((1, 2, 2,), dtype=np.int64) labels_rank_1 = np.array((1, 2, 2,), dtype=np.int64)
@ -616,7 +617,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
def test_weighted_multi_example_train(self): def test_weighted_multi_example_train(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes, weight_feature_key='label_weights') n_classes, weight_column='label_weights')
# Create estimator spec. # Create estimator spec.
logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32) logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)
@ -985,7 +986,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
def test_weighted_multi_example_predict(self): def test_weighted_multi_example_predict(self):
"""3 examples, 1 batch.""" """3 examples, 1 batch."""
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
weight_feature_key='label_weights') weight_column='label_weights')
# Create estimator spec. # Create estimator spec.
logits = np.array(((45,), (-41,), (44,)), dtype=np.int32) logits = np.array(((45,), (-41,), (44,)), dtype=np.int32)
@ -1018,7 +1019,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
def test_weighted_multi_example_eval(self): def test_weighted_multi_example_eval(self):
"""3 examples, 1 batch.""" """3 examples, 1 batch."""
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
weight_feature_key='label_weights') weight_column='label_weights')
# Create estimator spec. # Create estimator spec.
logits = np.array(((45,), (-41,), (44,)), dtype=np.int32) logits = np.array(((45,), (-41,), (44,)), dtype=np.int32)
@ -1072,7 +1073,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
def test_train_with_one_dim_labels_and_weights(self): def test_train_with_one_dim_labels_and_weights(self):
"""3 examples, 1 batch.""" """3 examples, 1 batch."""
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
weight_feature_key='label_weights') weight_column='label_weights')
# Create estimator spec. # Create estimator spec.
logits = np.array(((45,), (-41,), (44,)), dtype=np.float32) logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)
@ -1123,7 +1124,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
def test_weighted_multi_example_train(self): def test_weighted_multi_example_train(self):
"""3 examples, 1 batch.""" """3 examples, 1 batch."""
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
weight_feature_key='label_weights') weight_column='label_weights')
# Create estimator spec. # Create estimator spec.
logits = np.array(((45,), (-41,), (44,)), dtype=np.float32) logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)
@ -1403,7 +1404,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
def test_weighted_multi_example_eval(self): def test_weighted_multi_example_eval(self):
"""1d label, 3 examples, 1 batch.""" """1d label, 3 examples, 1 batch."""
head = head_lib._regression_head_with_mean_squared_error_loss( head = head_lib._regression_head_with_mean_squared_error_loss(
weight_feature_key='label_weights') weight_column='label_weights')
self.assertEqual(1, head.logits_dimension) self.assertEqual(1, head.logits_dimension)
# Create estimator spec. # Create estimator spec.
@ -1445,10 +1446,36 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
self.assertAllClose(expected_loss_mean, loss_mean) self.assertAllClose(expected_loss_mean, loss_mean)
self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval()) self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval())
def test_weight_with_numeric_column(self):
  """EVAL loss with weights taken from a `numeric_column` with a normalizer.

  1d label, 3 examples, 1 batch. The weight column applies
  `normalizer_fn=lambda x: x + 1.` to the 'label_weights' feature.
  """
  weight_column = feature_column_lib.numeric_column(
      'label_weights', normalizer_fn=lambda x: x + 1.)
  head = head_lib._regression_head_with_mean_squared_error_loss(
      weight_column=weight_column)

  # Inputs for the estimator spec.
  logits = np.array(((45,), (41,), (44,)), dtype=np.int32)
  labels = np.array(((35,), (42,), (45,)), dtype=np.int32)
  features = {
      'x': np.array(((42,), (43,), (44,)), dtype=np.int32),
      'label_weights': np.array(((0.,), (-0.9,), (0.5,)), dtype=np.float32),
  }
  spec = head.create_estimator_spec(
      features=features,
      mode=model_fn.ModeKeys.EVAL,
      logits=logits,
      labels=labels)

  # Weights after the normalizer are 1, .1 and 1.5, so
  # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6
  expected_loss = 101.6
  with self.test_session() as sess:
    _initialize_variables(self, spec.scaffold)
    loss = sess.run(spec.loss)
    self.assertAllClose(expected_loss, loss)
def test_weighted_multi_example_train(self): def test_weighted_multi_example_train(self):
"""1d label, 3 examples, 1 batch.""" """1d label, 3 examples, 1 batch."""
head = head_lib._regression_head_with_mean_squared_error_loss( head = head_lib._regression_head_with_mean_squared_error_loss(
weight_feature_key='label_weights') weight_column='label_weights')
self.assertEqual(1, head.logits_dimension) self.assertEqual(1, head.logits_dimension)
# Create estimator spec. # Create estimator spec.
@ -1500,7 +1527,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
def test_with_one_dim_label_and_weight(self): def test_with_one_dim_label_and_weight(self):
"""1d label, 3 examples, 1 batch.""" """1d label, 3 examples, 1 batch."""
head = head_lib._regression_head_with_mean_squared_error_loss( head = head_lib._regression_head_with_mean_squared_error_loss(
weight_feature_key='label_weights') weight_column='label_weights')
self.assertEqual(1, head.logits_dimension) self.assertEqual(1, head.logits_dimension)
# Create estimator spec. # Create estimator spec.
@ -1560,7 +1587,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
def test_weighted_multi_value_eval(self): def test_weighted_multi_value_eval(self):
"""3d label, 1 example, 1 batch.""" """3d label, 1 example, 1 batch."""
head = head_lib._regression_head_with_mean_squared_error_loss( head = head_lib._regression_head_with_mean_squared_error_loss(
weight_feature_key='label_weights', label_dimension=3) weight_column='label_weights', label_dimension=3)
self.assertEqual(3, head.logits_dimension) self.assertEqual(3, head.logits_dimension)
# Create estimator spec. # Create estimator spec.
@ -1605,7 +1632,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
def test_weighted_multi_value_train(self): def test_weighted_multi_value_train(self):
"""3d label, 1 example, 1 batch.""" """3d label, 1 example, 1 batch."""
head = head_lib._regression_head_with_mean_squared_error_loss( head = head_lib._regression_head_with_mean_squared_error_loss(
weight_feature_key='label_weights', label_dimension=3) weight_column='label_weights', label_dimension=3)
self.assertEqual(3, head.logits_dimension) self.assertEqual(3, head.logits_dimension)
# Create estimator spec. # Create estimator spec.
@ -1657,7 +1684,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
def test_weighted_multi_batch_eval(self): def test_weighted_multi_batch_eval(self):
"""1d label, 1 example, 3 batches.""" """1d label, 1 example, 3 batches."""
head = head_lib._regression_head_with_mean_squared_error_loss( head = head_lib._regression_head_with_mean_squared_error_loss(
weight_feature_key='label_weights') weight_column='label_weights')
self.assertEqual(1, head.logits_dimension) self.assertEqual(1, head.logits_dimension)
# Create estimator spec. # Create estimator spec.
@ -1723,7 +1750,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
def test_weighted_multi_batch_train(self): def test_weighted_multi_batch_train(self):
"""1d label, 1 example, 3 batches.""" """1d label, 1 example, 3 batches."""
head = head_lib._regression_head_with_mean_squared_error_loss( head = head_lib._regression_head_with_mean_squared_error_loss(
weight_feature_key='label_weights') weight_column='label_weights')
self.assertEqual(1, head.logits_dimension) self.assertEqual(1, head.logits_dimension)
# Create estimator spec. # Create estimator spec.

View File

@ -191,10 +191,10 @@ class LinearClassifier(estimator.Estimator):
""" """
if n_classes == 2: if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_feature_key=weight_feature_key) weight_column=weight_feature_key)
else: else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
n_classes, weight_feature_key=weight_feature_key) n_classes, weight_column=weight_feature_key)
super(LinearClassifier, self).__init__( super(LinearClassifier, self).__init__(
model_fn=_linear_model_fn, model_fn=_linear_model_fn,
model_dir=model_dir, model_dir=model_dir,
@ -284,11 +284,15 @@ class LinearRegressor(estimator.Estimator):
config=config, config=config,
params={ params={
# pylint: disable=protected-access # pylint: disable=protected-access
'head': head_lib._regression_head_with_mean_squared_error_loss( 'head':
label_dimension=label_dimension, head_lib._regression_head_with_mean_squared_error_loss(
weight_feature_key=weight_feature_key), label_dimension=label_dimension,
weight_column=weight_feature_key),
# pylint: enable=protected-access # pylint: enable=protected-access
'feature_columns': feature_columns, 'feature_columns':
'optimizer': optimizer, feature_columns,
'partitioner': partitioner, 'optimizer':
optimizer,
'partitioner':
partitioner,
}) })