Add weight-column support to the heads.
PiperOrigin-RevId: 158409180
This commit is contained in:
parent
7fb52cd54c
commit
d35cbbb447
tensorflow/python/estimator
@ -394,6 +394,7 @@ py_library(
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:weights_broadcast_ops",
|
||||
"//tensorflow/python/feature_column",
|
||||
"//tensorflow/python/ops/losses",
|
||||
],
|
||||
)
|
||||
|
@ -230,10 +230,10 @@ class DNNClassifier(estimator.Estimator):
|
||||
"""
|
||||
if n_classes == 2:
|
||||
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
|
||||
weight_feature_key=weight_feature_key)
|
||||
weight_column=weight_feature_key)
|
||||
else:
|
||||
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
|
||||
n_classes, weight_feature_key=weight_feature_key)
|
||||
n_classes, weight_column=weight_feature_key)
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return _dnn_model_fn(
|
||||
features=features,
|
||||
@ -351,9 +351,10 @@ class DNNRegressor(estimator.Estimator):
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
head=head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access
|
||||
head=head_lib. # pylint: disable=protected-access
|
||||
_regression_head_with_mean_squared_error_loss(
|
||||
label_dimension=label_dimension,
|
||||
weight_feature_key=weight_feature_key),
|
||||
weight_column=weight_feature_key),
|
||||
hidden_units=hidden_units,
|
||||
feature_columns=tuple(feature_columns or []),
|
||||
optimizer=optimizer,
|
||||
|
@ -307,6 +307,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
|
||||
dnn_activation_fn=nn.relu,
|
||||
dnn_dropout=None,
|
||||
n_classes=2,
|
||||
weight_feature_key=None,
|
||||
input_layer_partitioner=None,
|
||||
config=None):
|
||||
"""Initializes a DNNLinearCombinedClassifier instance.
|
||||
@ -333,6 +334,9 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
|
||||
a given coordinate.
|
||||
n_classes: Number of label classes. Defaults to 2, namely binary
|
||||
classification. Must be > 1.
|
||||
weight_feature_key: A string defining feature column name representing
|
||||
weights. It is used to down weight or boost examples during training. It
|
||||
will be multiplied by the loss of the example.
|
||||
input_layer_partitioner: Partitioner for input layer. Defaults to
|
||||
`min_max_variable_partitioner` with `min_slice_size` 64 << 20.
|
||||
config: RunConfig object to configure the runtime settings.
|
||||
@ -348,11 +352,12 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
|
||||
raise ValueError('Either linear_feature_columns or dnn_feature_columns '
|
||||
'must be defined.')
|
||||
if n_classes == 2:
|
||||
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss() # pylint: disable=protected-access
|
||||
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
|
||||
weight_column=weight_feature_key)
|
||||
else:
|
||||
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
|
||||
n_classes)
|
||||
|
||||
n_classes,
|
||||
weight_column=weight_feature_key)
|
||||
def _model_fn(features, labels, mode, config):
|
||||
return _dnn_linear_combined_model_fn(
|
||||
features=features,
|
||||
@ -500,7 +505,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
|
||||
head=head_lib. # pylint: disable=protected-access
|
||||
_regression_head_with_mean_squared_error_loss(
|
||||
label_dimension=label_dimension,
|
||||
weight_feature_key=weight_feature_key),
|
||||
weight_column=weight_feature_key),
|
||||
linear_feature_columns=linear_feature_columns,
|
||||
linear_optimizer=linear_optimizer,
|
||||
dnn_feature_columns=dnn_feature_columns,
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.estimator import model_fn
|
||||
from tensorflow.python.estimator.canned import metric_keys
|
||||
from tensorflow.python.estimator.canned import prediction_keys
|
||||
from tensorflow.python.estimator.export import export_output
|
||||
from tensorflow.python.feature_column import feature_column as feature_column_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -278,7 +279,7 @@ def _recall_at_threshold(labels, predictions, weights, threshold, name=None):
|
||||
|
||||
|
||||
def _multi_class_head_with_softmax_cross_entropy_loss(n_classes,
|
||||
weight_feature_key=None,
|
||||
weight_column=None,
|
||||
label_vocabulary=None):
|
||||
"""Creates a '_Head' for multi class classification.
|
||||
|
||||
@ -287,7 +288,8 @@ def _multi_class_head_with_softmax_cross_entropy_loss(n_classes,
|
||||
Args:
|
||||
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
|
||||
`_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
|
||||
weight_feature_key: A string defining feature column name representing
|
||||
weight_column: A string or a `_NumericColumn` created by
|
||||
`tf.feature_column.numeric_column` defining feature column representing
|
||||
weights. It is used to down weight or boost examples during training. It
|
||||
will be multiplied by the loss of the example.
|
||||
label_vocabulary: A list of strings represents possible label values. If it
|
||||
@ -307,18 +309,18 @@ def _multi_class_head_with_softmax_cross_entropy_loss(n_classes,
|
||||
raise ValueError('label_vocabulary should be a list. Given type: {}'.format(
|
||||
type(label_vocabulary)))
|
||||
|
||||
return _MultiClassHeadWithSoftmaxCrossEntropyLoss(
|
||||
n_classes, weight_feature_key, label_vocabulary)
|
||||
return _MultiClassHeadWithSoftmaxCrossEntropyLoss(n_classes, weight_column,
|
||||
label_vocabulary)
|
||||
|
||||
|
||||
class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
|
||||
"""See `_multi_class_head_with_softmax_cross_entropy_loss`."""
|
||||
|
||||
def __init__(self, n_classes, weight_feature_key=None, label_vocabulary=None):
|
||||
def __init__(self, n_classes, weight_column=None, label_vocabulary=None):
|
||||
if (n_classes is None) or (n_classes <= 2):
|
||||
raise ValueError('n_classes must be > 2: %s.' % n_classes)
|
||||
self._n_classes = n_classes
|
||||
self._weight_feature_key = weight_feature_key
|
||||
self._weight_column = weight_column
|
||||
self._label_vocabulary = label_vocabulary
|
||||
|
||||
@property
|
||||
@ -417,10 +419,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
|
||||
labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
|
||||
# Restore the squeezed dim, so unweighted_loss matches the weights shape.
|
||||
unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=(1,))
|
||||
weights = (
|
||||
1. if (self._weight_feature_key is None) else
|
||||
features[self._weight_feature_key])
|
||||
weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
|
||||
weights = _weights(features, self._weight_column)
|
||||
training_loss = losses.compute_weighted_loss(
|
||||
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
|
||||
if mode == model_fn.ModeKeys.EVAL:
|
||||
@ -453,7 +452,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
|
||||
|
||||
|
||||
def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
weight_feature_key=None, thresholds=None, label_vocabulary=None):
|
||||
weight_column=None, thresholds=None, label_vocabulary=None):
|
||||
"""Creates a `Head` for single label binary classification.
|
||||
|
||||
This head uses `sigmoid_cross_entropy_with_logits` loss.
|
||||
@ -461,7 +460,8 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
This head expects to be fed float labels of shape `(batch_size, 1)`.
|
||||
|
||||
Args:
|
||||
weight_feature_key: A string defining feature column name representing
|
||||
weight_column: A string or a `_NumericColumn` created by
|
||||
`tf.feature_column.numeric_column` defining feature column representing
|
||||
weights. It is used to down weight or boost examples during training. It
|
||||
will be multiplied by the loss of the example.
|
||||
thresholds: Iterable of floats in the range `(0, 1)`. For binary
|
||||
@ -491,7 +491,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
if (threshold <= 0.0) or (threshold >= 1.0):
|
||||
raise ValueError('thresholds not in (0, 1): %s.' % (thresholds,))
|
||||
return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
|
||||
weight_feature_key=weight_feature_key,
|
||||
weight_column=weight_column,
|
||||
thresholds=thresholds,
|
||||
label_vocabulary=label_vocabulary)
|
||||
|
||||
@ -499,11 +499,9 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
|
||||
"""See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`."""
|
||||
|
||||
def __init__(self,
|
||||
weight_feature_key=None,
|
||||
thresholds=None,
|
||||
def __init__(self, weight_column=None, thresholds=None,
|
||||
label_vocabulary=None):
|
||||
self._weight_feature_key = weight_feature_key
|
||||
self._weight_column = weight_column
|
||||
self._thresholds = thresholds
|
||||
self._label_vocabulary = label_vocabulary
|
||||
|
||||
@ -624,10 +622,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
|
||||
labels = _assert_range(labels, 2)
|
||||
unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
|
||||
labels=labels, logits=logits, name='loss')
|
||||
weights = (
|
||||
1. if (self._weight_feature_key is None) else
|
||||
features[self._weight_feature_key])
|
||||
weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
|
||||
weights = _weights(features, self._weight_column)
|
||||
training_loss = losses.compute_weighted_loss(
|
||||
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
|
||||
if mode == model_fn.ModeKeys.EVAL:
|
||||
@ -660,13 +655,13 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
|
||||
train_op=train_op_fn(training_loss))
|
||||
|
||||
|
||||
def _regression_head_with_mean_squared_error_loss(
|
||||
weight_feature_key=None,
|
||||
label_dimension=1):
|
||||
def _regression_head_with_mean_squared_error_loss(weight_column=None,
|
||||
label_dimension=1):
|
||||
"""Creates a `_Head` for regression using the mean squared loss.
|
||||
|
||||
Args:
|
||||
weight_feature_key: A string defining feature column name representing
|
||||
weight_column: A string or a `_NumericColumn` created by
|
||||
`tf.feature_column.numeric_column` defining feature column representing
|
||||
weights. It is used to down weight or boost examples during training. It
|
||||
will be multiplied by the loss of the example.
|
||||
label_dimension: Number of regression labels per example. This is the size
|
||||
@ -677,33 +672,18 @@ def _regression_head_with_mean_squared_error_loss(
|
||||
An instance of `_Head` for linear regression.
|
||||
"""
|
||||
return _RegressionHeadWithMeanSquaredErrorLoss(
|
||||
weight_feature_key=weight_feature_key,
|
||||
label_dimension=label_dimension)
|
||||
weight_column=weight_column, label_dimension=label_dimension)
|
||||
|
||||
|
||||
class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
|
||||
"""`Head` for regression using the mean squared loss."""
|
||||
|
||||
def __init__(self,
|
||||
label_dimension,
|
||||
weight_feature_key=None):
|
||||
"""`Head` for regression.
|
||||
|
||||
Args:
|
||||
label_dimension: Number of regression labels per example. This is the
|
||||
size of the last dimension of the labels `Tensor` (typically, this has
|
||||
shape `[batch_size, label_dimension]`).
|
||||
weight_feature_key: A string defining feature column name representing
|
||||
weights. It is used to down weight or boost examples during training. It
|
||||
will be multiplied by the loss of the example.
|
||||
|
||||
Raises:
|
||||
ValueError: if `label_dimension` < 1.
|
||||
"""
|
||||
def __init__(self, label_dimension, weight_column=None):
|
||||
"""`Head` for regression."""
|
||||
if label_dimension < 1:
|
||||
raise ValueError('Invalid label_dimension %s.' % label_dimension)
|
||||
self._logits_dimension = label_dimension
|
||||
self._weight_feature_key = weight_feature_key
|
||||
self._weight_column = weight_column
|
||||
|
||||
@property
|
||||
def logits_dimension(self):
|
||||
@ -731,10 +711,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
|
||||
self._logits_dimension)
|
||||
unweighted_loss = losses.mean_squared_error(
|
||||
labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
|
||||
weights = (
|
||||
1. if (self._weight_feature_key is None) else
|
||||
features[self._weight_feature_key])
|
||||
weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
|
||||
weights = _weights(features, self._weight_column)
|
||||
training_loss = losses.compute_weighted_loss(
|
||||
unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
|
||||
if mode == model_fn.ModeKeys.EVAL:
|
||||
@ -774,3 +751,21 @@ def _assert_range(labels, n_classes):
|
||||
labels, message='Label IDs must >= 0')
|
||||
with ops.control_dependencies((assert_less, assert_greater)):
|
||||
return array_ops.identity(labels)
|
||||
|
||||
|
||||
def _weights(features, weight_column):
|
||||
"""Fetches weights from features."""
|
||||
if weight_column is None:
|
||||
return 1.
|
||||
if isinstance(weight_column, six.string_types):
|
||||
weight_column = feature_column_lib.numeric_column(key=weight_column)
|
||||
if not isinstance(weight_column, feature_column_lib._NumericColumn): # pylint: disable=protected-access
|
||||
raise TypeError('Weight column must be either a string or _NumericColumn. '
|
||||
'Given type: {}.'.format(type(weight_column)))
|
||||
weights = weight_column._get_dense_tensor( # pylint: disable=protected-access
|
||||
feature_column_lib._LazyBuilder(features)) # pylint: disable=protected-access
|
||||
if not (weights.dtype.is_floating or weights.dtype.is_integer):
|
||||
raise ValueError('Weight column should be castable to float. '
|
||||
'Given dtype: {}'.format(weights.dtype))
|
||||
weights = _maybe_expand_dim(math_ops.to_float(weights, name='weights'))
|
||||
return weights
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.estimator.canned import head as head_lib
|
||||
from tensorflow.python.estimator.canned import metric_keys
|
||||
from tensorflow.python.estimator.canned import prediction_keys
|
||||
from tensorflow.python.estimator.inputs import numpy_io
|
||||
from tensorflow.python.feature_column import feature_column as feature_column_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
@ -338,7 +339,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
|
||||
expected_probabilities = [[0.576117, 0.2119416, 0.2119416],
|
||||
[0.2119416, 0.2119416, 0.576117]]
|
||||
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||
n_classes, weight_feature_key='label_weights')
|
||||
n_classes, weight_column='label_weights')
|
||||
|
||||
weights_2x1 = [[1.], [2.]]
|
||||
spec = head.create_estimator_spec(
|
||||
@ -440,7 +441,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
|
||||
def test_weighted_multi_example_eval(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||
n_classes, weight_feature_key='label_weights')
|
||||
n_classes, weight_column='label_weights')
|
||||
|
||||
# Create estimator spec.
|
||||
logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)
|
||||
@ -534,7 +535,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
|
||||
def test_train_with_one_dim_label_and_weights(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||
n_classes, weight_feature_key='label_weights')
|
||||
n_classes, weight_column='label_weights')
|
||||
|
||||
logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)
|
||||
labels_rank_1 = np.array((1, 2, 2,), dtype=np.int64)
|
||||
@ -616,7 +617,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
|
||||
def test_weighted_multi_example_train(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
|
||||
n_classes, weight_feature_key='label_weights')
|
||||
n_classes, weight_column='label_weights')
|
||||
|
||||
# Create estimator spec.
|
||||
logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)
|
||||
@ -985,7 +986,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
|
||||
def test_weighted_multi_example_predict(self):
|
||||
"""3 examples, 1 batch."""
|
||||
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
weight_feature_key='label_weights')
|
||||
weight_column='label_weights')
|
||||
|
||||
# Create estimator spec.
|
||||
logits = np.array(((45,), (-41,), (44,)), dtype=np.int32)
|
||||
@ -1018,7 +1019,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
|
||||
def test_weighted_multi_example_eval(self):
|
||||
"""3 examples, 1 batch."""
|
||||
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
weight_feature_key='label_weights')
|
||||
weight_column='label_weights')
|
||||
|
||||
# Create estimator spec.
|
||||
logits = np.array(((45,), (-41,), (44,)), dtype=np.int32)
|
||||
@ -1072,7 +1073,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
|
||||
def test_train_with_one_dim_labels_and_weights(self):
|
||||
"""3 examples, 1 batch."""
|
||||
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
weight_feature_key='label_weights')
|
||||
weight_column='label_weights')
|
||||
|
||||
# Create estimator spec.
|
||||
logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)
|
||||
@ -1123,7 +1124,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
|
||||
def test_weighted_multi_example_train(self):
|
||||
"""3 examples, 1 batch."""
|
||||
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
|
||||
weight_feature_key='label_weights')
|
||||
weight_column='label_weights')
|
||||
|
||||
# Create estimator spec.
|
||||
logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)
|
||||
@ -1403,7 +1404,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
|
||||
def test_weighted_multi_example_eval(self):
|
||||
"""1d label, 3 examples, 1 batch."""
|
||||
head = head_lib._regression_head_with_mean_squared_error_loss(
|
||||
weight_feature_key='label_weights')
|
||||
weight_column='label_weights')
|
||||
self.assertEqual(1, head.logits_dimension)
|
||||
|
||||
# Create estimator spec.
|
||||
@ -1445,10 +1446,36 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
|
||||
self.assertAllClose(expected_loss_mean, loss_mean)
|
||||
self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval())
|
||||
|
||||
def test_weight_with_numeric_column(self):
|
||||
"""1d label, 3 examples, 1 batch."""
|
||||
head = head_lib._regression_head_with_mean_squared_error_loss(
|
||||
weight_column=feature_column_lib.numeric_column(
|
||||
'label_weights', normalizer_fn=lambda x: x + 1.))
|
||||
|
||||
# Create estimator spec.
|
||||
logits = np.array(((45,), (41,), (44,)), dtype=np.int32)
|
||||
spec = head.create_estimator_spec(
|
||||
features={
|
||||
'x':
|
||||
np.array(((42,), (43,), (44,)), dtype=np.int32),
|
||||
'label_weights':
|
||||
np.array(((0.,), (-0.9,), (0.5,)), dtype=np.float32),
|
||||
},
|
||||
mode=model_fn.ModeKeys.EVAL,
|
||||
logits=logits,
|
||||
labels=np.array(((35,), (42,), (45,)), dtype=np.int32))
|
||||
|
||||
# Assert loss.
|
||||
with self.test_session() as sess:
|
||||
_initialize_variables(self, spec.scaffold)
|
||||
loss = sess.run(spec.loss)
|
||||
# loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6
|
||||
self.assertAllClose(101.6, loss)
|
||||
|
||||
def test_weighted_multi_example_train(self):
|
||||
"""1d label, 3 examples, 1 batch."""
|
||||
head = head_lib._regression_head_with_mean_squared_error_loss(
|
||||
weight_feature_key='label_weights')
|
||||
weight_column='label_weights')
|
||||
self.assertEqual(1, head.logits_dimension)
|
||||
|
||||
# Create estimator spec.
|
||||
@ -1500,7 +1527,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
|
||||
def test_with_one_dim_label_and_weight(self):
|
||||
"""1d label, 3 examples, 1 batch."""
|
||||
head = head_lib._regression_head_with_mean_squared_error_loss(
|
||||
weight_feature_key='label_weights')
|
||||
weight_column='label_weights')
|
||||
self.assertEqual(1, head.logits_dimension)
|
||||
|
||||
# Create estimator spec.
|
||||
@ -1560,7 +1587,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
|
||||
def test_weighted_multi_value_eval(self):
|
||||
"""3d label, 1 example, 1 batch."""
|
||||
head = head_lib._regression_head_with_mean_squared_error_loss(
|
||||
weight_feature_key='label_weights', label_dimension=3)
|
||||
weight_column='label_weights', label_dimension=3)
|
||||
self.assertEqual(3, head.logits_dimension)
|
||||
|
||||
# Create estimator spec.
|
||||
@ -1605,7 +1632,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
|
||||
def test_weighted_multi_value_train(self):
|
||||
"""3d label, 1 example, 1 batch."""
|
||||
head = head_lib._regression_head_with_mean_squared_error_loss(
|
||||
weight_feature_key='label_weights', label_dimension=3)
|
||||
weight_column='label_weights', label_dimension=3)
|
||||
self.assertEqual(3, head.logits_dimension)
|
||||
|
||||
# Create estimator spec.
|
||||
@ -1657,7 +1684,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
|
||||
def test_weighted_multi_batch_eval(self):
|
||||
"""1d label, 1 example, 3 batches."""
|
||||
head = head_lib._regression_head_with_mean_squared_error_loss(
|
||||
weight_feature_key='label_weights')
|
||||
weight_column='label_weights')
|
||||
self.assertEqual(1, head.logits_dimension)
|
||||
|
||||
# Create estimator spec.
|
||||
@ -1723,7 +1750,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
|
||||
def test_weighted_multi_batch_train(self):
|
||||
"""1d label, 1 example, 3 batches."""
|
||||
head = head_lib._regression_head_with_mean_squared_error_loss(
|
||||
weight_feature_key='label_weights')
|
||||
weight_column='label_weights')
|
||||
self.assertEqual(1, head.logits_dimension)
|
||||
|
||||
# Create estimator spec.
|
||||
|
@ -191,10 +191,10 @@ class LinearClassifier(estimator.Estimator):
|
||||
"""
|
||||
if n_classes == 2:
|
||||
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
|
||||
weight_feature_key=weight_feature_key)
|
||||
weight_column=weight_feature_key)
|
||||
else:
|
||||
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
|
||||
n_classes, weight_feature_key=weight_feature_key)
|
||||
n_classes, weight_column=weight_feature_key)
|
||||
super(LinearClassifier, self).__init__(
|
||||
model_fn=_linear_model_fn,
|
||||
model_dir=model_dir,
|
||||
@ -284,11 +284,15 @@ class LinearRegressor(estimator.Estimator):
|
||||
config=config,
|
||||
params={
|
||||
# pylint: disable=protected-access
|
||||
'head': head_lib._regression_head_with_mean_squared_error_loss(
|
||||
label_dimension=label_dimension,
|
||||
weight_feature_key=weight_feature_key),
|
||||
'head':
|
||||
head_lib._regression_head_with_mean_squared_error_loss(
|
||||
label_dimension=label_dimension,
|
||||
weight_column=weight_feature_key),
|
||||
# pylint: enable=protected-access
|
||||
'feature_columns': feature_columns,
|
||||
'optimizer': optimizer,
|
||||
'partitioner': partitioner,
|
||||
'feature_columns':
|
||||
feature_columns,
|
||||
'optimizer':
|
||||
optimizer,
|
||||
'partitioner':
|
||||
partitioner,
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user