Makes the head/multi-head API public and updates selected users while leaving other users using aliases.
This is CL#1 of a series of CLs to make head/multi-head API public and migrate all users. This CL, + Makes the Head interface and factory functions public. + Updates all tf-learn internal and SIR usage. + Leaves aliases for the legacy private names which will be removed with all existing usages in next CLs. + Also, updates documentation. Change: 149613397
This commit is contained in:
parent
5b3e560d2f
commit
58067591b6
@ -28,13 +28,24 @@ See the @{$python/contrib.learn} guide.
|
||||
@@MetricSpec
|
||||
@@PredictionKey
|
||||
@@DNNClassifier
|
||||
@@DNNEstimator
|
||||
@@DNNRegressor
|
||||
@@DNNLinearCombinedRegressor
|
||||
@@DNNLinearCombinedClassifier
|
||||
@@LinearClassifier
|
||||
@@LinearEstimator
|
||||
@@LinearRegressor
|
||||
@@LogisticRegressor
|
||||
|
||||
@@Head
|
||||
@@multi_class_head
|
||||
@@multi_label_head
|
||||
@@binary_svm_head
|
||||
@@regression_head
|
||||
@@poisson_regression_head
|
||||
@@multi_head
|
||||
@@no_op_train_fn
|
||||
|
||||
@@Experiment
|
||||
@@ExportStrategy
|
||||
@@TaskType
|
||||
|
@ -295,6 +295,7 @@ from __future__ import print_function
|
||||
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
|
||||
from tensorflow.contrib.learn.python.learn.estimators.constants import ProblemType
|
||||
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
|
||||
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNEstimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedClassifier
|
||||
from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedRegressor
|
||||
@ -304,8 +305,17 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
|
||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
|
||||
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import Head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import multi_head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import no_op_train_fn
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import poisson_regression_head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.head import regression_head
|
||||
from tensorflow.contrib.learn.python.learn.estimators.kmeans import KMeansClustering
|
||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
|
||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator
|
||||
from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
|
||||
from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
|
||||
|
@ -131,7 +131,7 @@ class ComposableModelTest(test.TestCase):
|
||||
language = feature_column.sparse_column_with_hash_bucket('language', 100)
|
||||
age = feature_column.real_valued_column('age')
|
||||
|
||||
head = head_lib._multi_class_head(n_classes=2)
|
||||
head = head_lib.multi_class_head(n_classes=2)
|
||||
classifier = _linear_estimator(head, feature_columns=[age, language])
|
||||
|
||||
classifier.fit(input_fn=input_fn, steps=1000)
|
||||
@ -157,7 +157,7 @@ class ComposableModelTest(test.TestCase):
|
||||
language = feature_column.sparse_column_with_hash_bucket('language', 100)
|
||||
age = feature_column.sparse_column_with_hash_bucket('age', 2)
|
||||
|
||||
head = head_lib._multi_class_head(n_classes=2)
|
||||
head = head_lib.multi_class_head(n_classes=2)
|
||||
classifier = _joint_linear_estimator(head, feature_columns=[age, language])
|
||||
|
||||
classifier.fit(input_fn=input_fn, steps=1000)
|
||||
@ -171,7 +171,7 @@ class ComposableModelTest(test.TestCase):
|
||||
"""Tests multi-class classification using matrix data as input."""
|
||||
cont_features = [feature_column.real_valued_column('feature', dimension=4)]
|
||||
|
||||
head = head_lib._multi_class_head(n_classes=3)
|
||||
head = head_lib.multi_class_head(n_classes=3)
|
||||
classifier = _dnn_estimator(
|
||||
head, feature_columns=cont_features, hidden_units=[3, 3])
|
||||
|
||||
|
@ -304,7 +304,7 @@ class DNNClassifier(estimator.Estimator):
|
||||
config=config,
|
||||
params={
|
||||
"head":
|
||||
head_lib._multi_class_head( # pylint: disable=protected-access
|
||||
head_lib.multi_class_head(
|
||||
n_classes,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=enable_centered_bias),
|
||||
@ -579,7 +579,7 @@ class DNNRegressor(estimator.Estimator):
|
||||
config=config,
|
||||
params={
|
||||
"head":
|
||||
head_lib._regression_head( # pylint: disable=protected-access
|
||||
head_lib.regression_head(
|
||||
label_dimension=label_dimension,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=enable_centered_bias),
|
||||
@ -731,8 +731,7 @@ class DNNRegressor(estimator.Estimator):
|
||||
exports_to_keep=exports_to_keep)
|
||||
|
||||
|
||||
# TODO(zakaria): Make it public when b/34751732 is fixed.
|
||||
class _DNNEstimator(estimator.Estimator):
|
||||
class DNNEstimator(estimator.Estimator):
|
||||
"""A Estimator for TensorFlow DNN models with user specified _Head.
|
||||
|
||||
Example:
|
||||
@ -745,20 +744,20 @@ class _DNNEstimator(estimator.Estimator):
|
||||
...)
|
||||
sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b,
|
||||
...)
|
||||
To create a _DNNEstimator for binary classification, where
|
||||
estimator = _DNNEstimator(
|
||||
To create a DNNEstimator for binary classification, where
|
||||
estimator = DNNEstimator(
|
||||
feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
|
||||
head=head_lib._multi_class__head(n_classes=2),
|
||||
head=tf.contrib.learn.multi_class_head(n_classes=2),
|
||||
hidden_units=[1024, 512, 256])
|
||||
|
||||
If your label is keyed with "y" in your labels dict, and weights are keyed
|
||||
with "w" in features dict, and you want to enable centered bias,
|
||||
head = head_lib._multi_class__head(
|
||||
head = tf.contrib.learn.multi_class_head(
|
||||
n_classes=2,
|
||||
label_name="x",
|
||||
weight_column_name="w",
|
||||
enable_centered_bias=True)
|
||||
estimator = _DNNEstimator(
|
||||
estimator = DNNEstimator(
|
||||
feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
|
||||
head=head,
|
||||
hidden_units=[1024, 512, 256])
|
||||
@ -802,10 +801,10 @@ class _DNNEstimator(estimator.Estimator):
|
||||
feature_engineering_fn=None,
|
||||
embedding_lr_multipliers=None,
|
||||
input_layer_min_slice_size=None):
|
||||
"""Initializes a _DNNEstimator instance.
|
||||
"""Initializes a `DNNEstimator` instance.
|
||||
|
||||
Args:
|
||||
head: _Head instance.
|
||||
head: `Head` instance.
|
||||
hidden_units: List of hidden units per layer. All layers are fully
|
||||
connected. Ex. `[64, 32]` means first layer has 64 nodes and second one
|
||||
has 32.
|
||||
@ -836,9 +835,9 @@ class _DNNEstimator(estimator.Estimator):
|
||||
partitions. If not provided, will use the default of 64M.
|
||||
|
||||
Returns:
|
||||
A `_DNNEstimator` estimator.
|
||||
A `DNNEstimator` estimator.
|
||||
"""
|
||||
super(_DNNEstimator, self).__init__(
|
||||
super(DNNEstimator, self).__init__(
|
||||
model_fn=_dnn_model_fn,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
@ -854,4 +853,3 @@ class _DNNEstimator(estimator.Estimator):
|
||||
"input_layer_min_slice_size": input_layer_min_slice_size,
|
||||
},
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
||||
|
@ -550,7 +550,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
|
||||
if not self._feature_columns:
|
||||
raise ValueError("Either linear_feature_columns or dnn_feature_columns "
|
||||
"must be defined.")
|
||||
head = head_lib._multi_class_head( # pylint: disable=protected-access
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=enable_centered_bias)
|
||||
@ -841,7 +841,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
|
||||
if not self._feature_columns:
|
||||
raise ValueError("Either linear_feature_columns or dnn_feature_columns "
|
||||
"must be defined.")
|
||||
head = head_lib._regression_head( # pylint: disable=protected-access
|
||||
head = head_lib.regression_head(
|
||||
weight_column_name=weight_column_name,
|
||||
label_dimension=label_dimension,
|
||||
enable_centered_bias=enable_centered_bias)
|
||||
|
@ -60,7 +60,7 @@ def _assert_metrics_in_range(keys, metrics):
|
||||
metrics)
|
||||
|
||||
|
||||
class _CheckCallsHead(head_lib._Head): # pylint: disable=protected-access
|
||||
class _CheckCallsHead(head_lib.Head):
|
||||
"""Head that checks whether head_ops is called."""
|
||||
|
||||
def __init__(self):
|
||||
@ -97,7 +97,7 @@ class EmbeddingMultiplierTest(test.TestCase):
|
||||
|
||||
params = {
|
||||
'dnn_feature_columns': [one_hot_language],
|
||||
'head': head_lib._multi_class_head(2),
|
||||
'head': head_lib.multi_class_head(2),
|
||||
'dnn_hidden_units': [1],
|
||||
# Set lr mult to 0. to keep embeddings constant.
|
||||
'embedding_lr_multipliers': {
|
||||
@ -131,7 +131,7 @@ class EmbeddingMultiplierTest(test.TestCase):
|
||||
|
||||
params = {
|
||||
'dnn_feature_columns': [embedding_language, embedding_wire],
|
||||
'head': head_lib._multi_class_head(2),
|
||||
'head': head_lib.multi_class_head(2),
|
||||
'dnn_hidden_units': [1],
|
||||
# Set lr mult to 0. to keep embeddings constant.
|
||||
'embedding_lr_multipliers': {
|
||||
|
@ -59,7 +59,7 @@ class EmbeddingMultiplierTest(test.TestCase):
|
||||
|
||||
params = {
|
||||
'feature_columns': [one_hot_language],
|
||||
'head': head_lib._multi_class_head(2),
|
||||
'head': head_lib.multi_class_head(2),
|
||||
'hidden_units': [1],
|
||||
# Set lr mult to 0. to keep embeddings constant.
|
||||
'embedding_lr_multipliers': {
|
||||
@ -90,7 +90,7 @@ class EmbeddingMultiplierTest(test.TestCase):
|
||||
|
||||
params = {
|
||||
'feature_columns': [embedding_language, embedding_wire],
|
||||
'head': head_lib._multi_class_head(2),
|
||||
'head': head_lib.multi_class_head(2),
|
||||
'hidden_units': [1],
|
||||
# Set lr mult to 0. to keep embeddings constant.
|
||||
'embedding_lr_multipliers': {
|
||||
@ -145,7 +145,7 @@ class DNNEstimatorTest(test.TestCase):
|
||||
exp.test()
|
||||
|
||||
def testEstimatorContract(self):
|
||||
estimator_test_utils.assert_estimator_contract(self, dnn._DNNEstimator)
|
||||
estimator_test_utils.assert_estimator_contract(self, dnn.DNNEstimator)
|
||||
|
||||
def testTrainWithWeights(self):
|
||||
"""Tests training with given weight column."""
|
||||
@ -172,8 +172,8 @@ class DNNEstimatorTest(test.TestCase):
|
||||
}
|
||||
return features, labels
|
||||
|
||||
dnn_estimator = dnn._DNNEstimator(
|
||||
head=head_lib._multi_class_head(2, weight_column_name='w'),
|
||||
dnn_estimator = dnn.DNNEstimator(
|
||||
head=head_lib.multi_class_head(2, weight_column_name='w'),
|
||||
feature_columns=[feature_column.real_valued_column('x')],
|
||||
hidden_units=[3, 3],
|
||||
config=run_config.RunConfig(tf_random_seed=1))
|
||||
|
@ -46,15 +46,141 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.summary import summary
|
||||
from tensorflow.python.training import training
|
||||
|
||||
# TODO(zakaria): add functions that creates a head and returns ModelOpFn
|
||||
|
||||
class Head(object):
|
||||
"""Interface for the head/top of a model.
|
||||
|
||||
Given logits (or output of a hidden layer), a Head knows how to compute
|
||||
predictions, loss, default metric and export signature. It is meant to,
|
||||
|
||||
1) Simplify writing model_fn and to make model_fn more configurable
|
||||
2) Support wide range of machine learning models. Since most heads can work
|
||||
with logits, they can support DNN, RNN, Wide, Wide&Deep,
|
||||
Global objectives, Gradient boosted trees and many other types
|
||||
of machine learning models.
|
||||
2) To allow users to seamlessly switch between 1 to n heads for multi
|
||||
objective learning (See _MultiHead implementation for more details)
|
||||
|
||||
Common usage:
|
||||
Here is simplified model_fn to build a multiclass DNN model.
|
||||
```python
|
||||
def _my_dnn_model_fn(features, labels, mode, params, config=None):
|
||||
# Optionally your callers can pass head to model_fn as a param.
|
||||
head = tf.contrib.learn.multi_class_head(...)
|
||||
input = tf.contrib.layers.input_from_feature_columns(features, ...)
|
||||
last_hidden_layer_out = tf.contrib.layers.stack(
|
||||
input, tf.contrib.layers.fully_connected, [1000, 500])
|
||||
logits = tf.contrib.layers.fully_connected(
|
||||
last_hidden_layer_out, head.logits_dimension, activation_fn=None)
|
||||
|
||||
def _train_op_fn(loss):
|
||||
return optimizer.minimize(loss)
|
||||
|
||||
return head.create_model_fn_ops(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
train_op_fn=_train_op_fn,
|
||||
logits=logits,
|
||||
scope=...)
|
||||
```
|
||||
|
||||
Most heads also support logits_input which is typically the output of the last
|
||||
hidden layer. Some heads (like heads responsible for candidate sampling or
|
||||
hierarchical softmax) intrinsically will not support logits and you have
|
||||
to pass logits_input. Here is a common usage,
|
||||
```python
|
||||
return head.create_model_fn_ops(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
train_op_fn=_train_op_fn,
|
||||
logits_input=last_hidden_layer_out,
|
||||
scope=...)
|
||||
```python
|
||||
|
||||
There are cases where computing and applying gradients can not be meaningfully
|
||||
captured with train_op_fn we support (for example, with sync optimizer). In
|
||||
such case, you can take the responsibility on your own. Here is a common
|
||||
use case,
|
||||
```python
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
features=features,
|
||||
labels=labels,
|
||||
mode=mode,
|
||||
train_op_fn=tf.contrib.learn.no_op_train_fn,
|
||||
logits=logits,
|
||||
scope=...)
|
||||
if mode == tf.contrib.learn.ModeKeys.TRAIN:
|
||||
optimizer = ...
|
||||
sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
|
||||
update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
|
||||
loss=model_fn_ops.loss, ...)
|
||||
hooks = [sync.make_session_run_hook(is_chief)]
|
||||
... upate train_op and hooks in ModelFnOps and return
|
||||
```
|
||||
"""
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractproperty
|
||||
def logits_dimension(self):
|
||||
"""Size of the last dimension of the logits `Tensor`.
|
||||
|
||||
Typically, logits is of shape `[batch_size, logits_dimension]`.
|
||||
|
||||
Returns:
|
||||
The expected size of the `logits` tensor.
|
||||
"""
|
||||
raise NotImplementedError("Calling an abstract method.")
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_model_fn_ops(self,
|
||||
features,
|
||||
mode,
|
||||
labels=None,
|
||||
train_op_fn=None,
|
||||
logits=None,
|
||||
logits_input=None,
|
||||
scope=None):
|
||||
"""Returns `ModelFnOps` that a model_fn can return.
|
||||
|
||||
Please note that,
|
||||
+ Exactly one of `logits` and `logits_input` must be provided.
|
||||
+ All args must be passed via name.
|
||||
|
||||
Args:
|
||||
features: Input `dict` of `Tensor` objects.
|
||||
mode: Estimator's `ModeKeys`.
|
||||
labels: Labels `Tensor`, or `dict` of same.
|
||||
train_op_fn: Function that takes a scalar loss `Tensor` and returns an op
|
||||
to optimize the model with the loss. This is used in TRAIN mode and
|
||||
must not be None. None is allowed in other modes. If you want to
|
||||
optimize loss yourself you can pass `no_op_train_fn` and then use
|
||||
ModeFnOps.loss to compute and apply gradients.
|
||||
logits: logits `Tensor` to be used by the head.
|
||||
logits_input: `Tensor` from which to build logits, often needed when you
|
||||
don't want to compute the logits. Typicaly this is the activation of the
|
||||
last hidden layer in a DNN. Some heads (like the ones responsible for
|
||||
candidate sampling) intrinsically avoid computing full logits and only
|
||||
accepts logits_input.
|
||||
scope: Optional scope for `variable_scope`.
|
||||
|
||||
Returns:
|
||||
An instance of `ModelFnOps`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `mode` is not recognized.
|
||||
ValueError: If neither or both of `logits` and `logits_input` is provided.
|
||||
"""
|
||||
raise NotImplementedError("Calling an abstract method.")
|
||||
|
||||
|
||||
def _regression_head(label_name=None,
|
||||
def regression_head(label_name=None,
|
||||
weight_column_name=None,
|
||||
label_dimension=1,
|
||||
enable_centered_bias=False,
|
||||
head_name=None):
|
||||
"""Creates a _Head for linear regression.
|
||||
"""Creates a `Head` for linear regression.
|
||||
|
||||
Args:
|
||||
label_name: String, name of the key in label dict. Can be null if label
|
||||
@ -73,7 +199,7 @@ def _regression_head(label_name=None,
|
||||
will be `head_name`.
|
||||
|
||||
Returns:
|
||||
An instance of _Head
|
||||
An instance of `Head` for linear regression.
|
||||
"""
|
||||
return _RegressionHead(
|
||||
label_name=label_name,
|
||||
@ -85,12 +211,12 @@ def _regression_head(label_name=None,
|
||||
link_fn=array_ops.identity)
|
||||
|
||||
|
||||
def _poisson_regression_head(label_name=None,
|
||||
def poisson_regression_head(label_name=None,
|
||||
weight_column_name=None,
|
||||
label_dimension=1,
|
||||
enable_centered_bias=False,
|
||||
head_name=None):
|
||||
"""Creates a _Head for linear regression.
|
||||
"""Creates a `Head` for poisson regression.
|
||||
|
||||
Args:
|
||||
label_name: String, name of the key in label dict. Can be null if label
|
||||
@ -109,7 +235,7 @@ def _poisson_regression_head(label_name=None,
|
||||
will be `head_name`.
|
||||
|
||||
Returns:
|
||||
An instance of _Head
|
||||
An instance of `Head` for poisson regression.
|
||||
"""
|
||||
return _RegressionHead(
|
||||
label_name=label_name,
|
||||
@ -120,10 +246,10 @@ def _poisson_regression_head(label_name=None,
|
||||
loss_fn=_poisson_loss,
|
||||
link_fn=math_ops.exp)
|
||||
|
||||
# TODO(zakaria): Add logistic_regression_head
|
||||
# TODO(zakaria): Consider adding a _RegressionHead for logistic_regression
|
||||
|
||||
|
||||
def _multi_class_head(n_classes,
|
||||
def multi_class_head(n_classes,
|
||||
label_name=None,
|
||||
weight_column_name=None,
|
||||
enable_centered_bias=False,
|
||||
@ -131,7 +257,7 @@ def _multi_class_head(n_classes,
|
||||
thresholds=None,
|
||||
metric_class_ids=None,
|
||||
loss_fn=None):
|
||||
"""Creates a _Head for multi class single label classification.
|
||||
"""Creates a `Head` for multi class single label classification.
|
||||
|
||||
The Head uses softmax cross entropy loss.
|
||||
|
||||
@ -157,7 +283,7 @@ def _multi_class_head(n_classes,
|
||||
optional. See `tf.losses`
|
||||
|
||||
Returns:
|
||||
An instance of _MultiClassHead.
|
||||
An instance of `Head` for multi class classification.
|
||||
|
||||
Raises:
|
||||
ValueError: If `n_classes` is < 2, or `metric_class_ids` is provided when
|
||||
@ -193,13 +319,13 @@ def _multi_class_head(n_classes,
|
||||
loss_fn=loss_fn)
|
||||
|
||||
|
||||
def _binary_svm_head(
|
||||
def binary_svm_head(
|
||||
label_name=None,
|
||||
weight_column_name=None,
|
||||
enable_centered_bias=False,
|
||||
head_name=None,
|
||||
thresholds=None,):
|
||||
"""Creates a `_Head` for binary classification with SVMs.
|
||||
"""Creates a `Head` for binary classification with SVMs.
|
||||
|
||||
The head uses binary hinge loss.
|
||||
|
||||
@ -218,8 +344,7 @@ def _binary_svm_head(
|
||||
thresholds: thresholds for eval metrics, defaults to [.5]
|
||||
|
||||
Returns:
|
||||
An instance of `_Head`.
|
||||
|
||||
An instance of `Head` for binary classification with SVM.
|
||||
"""
|
||||
return _BinarySvmHead(
|
||||
label_name=label_name,
|
||||
@ -229,7 +354,7 @@ def _binary_svm_head(
|
||||
thresholds=thresholds)
|
||||
|
||||
|
||||
def _multi_label_head(n_classes,
|
||||
def multi_label_head(n_classes,
|
||||
label_name=None,
|
||||
weight_column_name=None,
|
||||
enable_centered_bias=False,
|
||||
@ -237,7 +362,7 @@ def _multi_label_head(n_classes,
|
||||
thresholds=None,
|
||||
metric_class_ids=None,
|
||||
loss_fn=None):
|
||||
"""Creates a _Head for multi label classification.
|
||||
"""Creates a Head for multi label classification.
|
||||
|
||||
The Head uses sigmoid cross entropy loss.
|
||||
|
||||
@ -262,7 +387,7 @@ def _multi_label_head(n_classes,
|
||||
optional. See `tf.losses`
|
||||
|
||||
Returns:
|
||||
An instance of _MultiLabelHead.
|
||||
An instance of `Head` for multi label classification.
|
||||
|
||||
Raises:
|
||||
ValueError: If n_classes is < 2
|
||||
@ -284,16 +409,16 @@ def _multi_label_head(n_classes,
|
||||
loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)
|
||||
|
||||
|
||||
def _multi_head(heads, loss_weights=None):
|
||||
def multi_head(heads, loss_weights=None):
|
||||
"""Creates a MultiHead stemming from same logits/hidden layer.
|
||||
|
||||
Args:
|
||||
heads: list of _Head objects.
|
||||
loss_weights: optional list of weights to be used to combine losses from
|
||||
heads: list of Head objects.
|
||||
loss_weights: optional list of weights to be used to merge losses from
|
||||
each head. All losses are weighted equally if not provided.
|
||||
|
||||
Returns:
|
||||
A _Head instance that combines multiple heads.
|
||||
A instance of `Head` that merges multiple heads.
|
||||
|
||||
Raises:
|
||||
ValueError: if heads and loss_weights have different size.
|
||||
@ -302,7 +427,7 @@ def _multi_head(heads, loss_weights=None):
|
||||
if len(loss_weights) != len(heads):
|
||||
raise ValueError("heads and loss_weights must have same size")
|
||||
|
||||
def _weighted_loss_combiner(losses):
|
||||
def _weighted_loss_merger(losses):
|
||||
if loss_weights:
|
||||
if len(losses) != len(loss_weights):
|
||||
raise ValueError("losses and loss_weights must have same size")
|
||||
@ -313,7 +438,7 @@ def _multi_head(heads, loss_weights=None):
|
||||
else:
|
||||
return math_ops.add_n(losses)
|
||||
|
||||
return _MultiHead(heads, loss_combiner=_weighted_loss_combiner)
|
||||
return _MultiHead(heads, loss_merger=_weighted_loss_merger)
|
||||
|
||||
|
||||
def no_op_train_fn(loss):
|
||||
@ -321,64 +446,7 @@ def no_op_train_fn(loss):
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
|
||||
# TODO(zakaria): Make the classes public once we are ready for users to subclass
|
||||
# them. See b/34751732
|
||||
class _Head(object):
|
||||
"""Interface for the head/top of a model.
|
||||
|
||||
Given logits or output of a hidden layer, a Head knows how to compute
|
||||
predictions, loss, default metric and export signature.
|
||||
"""
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@abc.abstractproperty
|
||||
def logits_dimension(self):
|
||||
"""Size of the last dimension of the logits `Tensor`.
|
||||
|
||||
Typically, logits is of shape `[batch_size, logits_dimension]`.
|
||||
|
||||
Returns:
|
||||
Number of logits values per example.
|
||||
"""
|
||||
raise NotImplementedError("Calling an abstract method.")
|
||||
|
||||
@abc.abstractmethod
|
||||
def create_model_fn_ops(self,
|
||||
features,
|
||||
mode,
|
||||
labels=None,
|
||||
train_op_fn=None,
|
||||
logits=None,
|
||||
logits_input=None,
|
||||
scope=None):
|
||||
"""Returns ops for a model_fn.
|
||||
|
||||
Exactly one of `logits` and `logits_input` must be provided.
|
||||
|
||||
All args must be passed via name.
|
||||
|
||||
Args:
|
||||
features: Input `dict` of `Tensor` objects.
|
||||
mode: Estimator's `ModeKeys`.
|
||||
labels: Labels `Tensor`, or `dict` of same.
|
||||
train_op_fn: Function that takes a scalar loss and returns an op to
|
||||
optimize with the loss. Must not be `None` in TRAIN mode. If you want
|
||||
to optimize loss yourself you can pass `no_op_train_fn`.
|
||||
logits: logits `Tensor`, or `dict` of same, to be used for the head.
|
||||
logits_input: `Tensor` from which to build logits.
|
||||
scope: Optional scope for `variable_scope`.
|
||||
|
||||
Returns:
|
||||
`ModelFnOps`.
|
||||
|
||||
Raises:
|
||||
ValueError: if `mode` is not recognized, or neither or both of `logits`
|
||||
and `logits_input` is provided.
|
||||
"""
|
||||
raise NotImplementedError("Calling an abstract method.")
|
||||
|
||||
|
||||
class _SingleHead(_Head):
|
||||
class _SingleHead(Head):
|
||||
"""Interface for a single head/top of a model."""
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
@ -565,7 +633,7 @@ def _create_model_fn_ops(features,
|
||||
|
||||
|
||||
class _RegressionHead(_SingleHead):
|
||||
"""_Head for regression with a generalized linear model."""
|
||||
"""`Head` for regression with a generalized linear model."""
|
||||
|
||||
def __init__(self,
|
||||
label_dimension,
|
||||
@ -575,7 +643,7 @@ class _RegressionHead(_SingleHead):
|
||||
weight_column_name=None,
|
||||
enable_centered_bias=False,
|
||||
head_name=None):
|
||||
"""Head for regression.
|
||||
"""`Head` for regression.
|
||||
|
||||
Args:
|
||||
label_dimension: Number of regression labels per example. This is the
|
||||
@ -614,7 +682,7 @@ class _RegressionHead(_SingleHead):
|
||||
logits=None,
|
||||
logits_input=None,
|
||||
scope=None):
|
||||
"""See `_Head`."""
|
||||
"""See `Head`."""
|
||||
return _create_model_fn_ops(
|
||||
features=features,
|
||||
mode=mode,
|
||||
@ -682,7 +750,7 @@ def _one_class_to_two_class_logits(logits):
|
||||
|
||||
|
||||
class _BinaryLogisticHead(_SingleHead):
|
||||
"""_Head for binary logistic classifciation."""
|
||||
"""`Head` for binary classification with logistic regression."""
|
||||
|
||||
def __init__(self,
|
||||
label_name=None,
|
||||
@ -691,7 +759,7 @@ class _BinaryLogisticHead(_SingleHead):
|
||||
head_name=None,
|
||||
loss_fn=None,
|
||||
thresholds=None):
|
||||
"""Base type for all single heads.
|
||||
"""`Head` for binary classification with logistic regression.
|
||||
|
||||
Args:
|
||||
label_name: String, name of the key in label dict. Can be `None` if label
|
||||
@ -729,7 +797,7 @@ class _BinaryLogisticHead(_SingleHead):
|
||||
logits=None,
|
||||
logits_input=None,
|
||||
scope=None):
|
||||
"""See `_Head`."""
|
||||
"""See `Head`."""
|
||||
return _create_model_fn_ops(
|
||||
features=features,
|
||||
mode=mode,
|
||||
@ -844,7 +912,7 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None):
|
||||
|
||||
|
||||
class _MultiClassHead(_SingleHead):
|
||||
"""_Head for classification."""
|
||||
"""'Head' for multi class classification."""
|
||||
|
||||
def __init__(self,
|
||||
n_classes,
|
||||
@ -855,7 +923,7 @@ class _MultiClassHead(_SingleHead):
|
||||
loss_fn=None,
|
||||
thresholds=None,
|
||||
metric_class_ids=None):
|
||||
"""_Head for classification.
|
||||
"""'Head' for multi class classification.
|
||||
|
||||
Args:
|
||||
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
|
||||
@ -905,7 +973,7 @@ class _MultiClassHead(_SingleHead):
|
||||
logits=None,
|
||||
logits_input=None,
|
||||
scope=None):
|
||||
"""See `_Head`."""
|
||||
"""See `Head`."""
|
||||
return _create_model_fn_ops(
|
||||
features=features,
|
||||
mode=mode,
|
||||
@ -1039,7 +1107,7 @@ def _assert_labels_rank(labels):
|
||||
|
||||
|
||||
class _BinarySvmHead(_SingleHead):
|
||||
"""_Head for binary classification using SVMs."""
|
||||
"""`Head` for binary classification using SVM."""
|
||||
|
||||
def __init__(self, label_name, weight_column_name, enable_centered_bias,
|
||||
head_name, thresholds):
|
||||
@ -1069,7 +1137,7 @@ class _BinarySvmHead(_SingleHead):
|
||||
logits=None,
|
||||
logits_input=None,
|
||||
scope=None):
|
||||
"""See `_Head`."""
|
||||
"""See `Head`."""
|
||||
return _create_model_fn_ops(
|
||||
features=features,
|
||||
mode=mode,
|
||||
@ -1125,7 +1193,7 @@ class _BinarySvmHead(_SingleHead):
|
||||
|
||||
|
||||
class _MultiLabelHead(_SingleHead):
|
||||
"""_Head for multlabel classification."""
|
||||
"""`Head` for multi-label classification."""
|
||||
|
||||
# TODO(zakaria): add signature and metric for multilabel.
|
||||
def __init__(self,
|
||||
@ -1162,7 +1230,7 @@ class _MultiLabelHead(_SingleHead):
|
||||
logits=None,
|
||||
logits_input=None,
|
||||
scope=None):
|
||||
"""See `_Head`."""
|
||||
"""See `Head`."""
|
||||
return _create_model_fn_ops(
|
||||
features=features,
|
||||
mode=mode,
|
||||
@ -1240,24 +1308,52 @@ class _MultiLabelHead(_SingleHead):
|
||||
return metrics
|
||||
|
||||
|
||||
class _MultiHead(_Head):
|
||||
"""_Head to combine multiple _Head objects.
|
||||
class _MultiHead(Head):
|
||||
"""`Head` implementation for multi objective learning.
|
||||
|
||||
This class is responsible for using and merging the output of multiple
|
||||
`Head` objects.
|
||||
|
||||
All heads stem from the same logits/logit_input tensor.
|
||||
|
||||
For training, combines losses of each heads according a function provided by
|
||||
user.
|
||||
For eval, adds a /head_name suffix to the keys in eval metrics.
|
||||
For inference, updates keys prediction dict to a 2-tuple,
|
||||
Common usage:
|
||||
For simple use cases you can pass the activation of hidden layer like
|
||||
this from your model_fn,
|
||||
```python
|
||||
last_hidden_layer_activation = ... Build your model.
|
||||
multi_head = ...
|
||||
return multi_head.create_model_fn_ops(
|
||||
..., logits_input=last_hidden_layer_activation, ...)
|
||||
```
|
||||
|
||||
Or you can create a logits tensor of
|
||||
[batch_size, multi_head.logits_dimension] shape. _MultiHead will split the
|
||||
logits for you.
|
||||
return multi_head.create_model_fn_ops(..., logits=logits, ...)
|
||||
|
||||
For more complex use cases like a multi-task/multi-tower model or when logits
|
||||
for each head has to be created separately, you can pass a dict of logits
|
||||
where the keys match the name of the single heads.
|
||||
```python
|
||||
logits = {"head1": logits1, "head2": logits2}
|
||||
return multi_head.create_model_fn_ops(..., logits=logits, ...)
|
||||
```
|
||||
|
||||
Here is what this class does,
|
||||
+ For training, merges losses of each heads according a function provided by
|
||||
user, calls user provided train_op_fn with this final loss.
|
||||
+ For eval, merges metrics by adding head_name suffix to the keys in eval
|
||||
metrics.
|
||||
+ For inference, updates keys in prediction dict to a 2-tuple,
|
||||
(head_name, prediction_key)
|
||||
"""
|
||||
|
||||
def __init__(self, heads, loss_combiner):
|
||||
"""_Head to combine multiple _Head objects.
|
||||
def __init__(self, heads, loss_merger):
|
||||
"""_Head to merges multiple _Head objects.
|
||||
|
||||
Args:
|
||||
heads: list of _Head objects.
|
||||
loss_combiner: function that takes a list of loss tensors for the heads
|
||||
loss_merger: function that takes a list of loss tensors for the heads
|
||||
and returns the final loss tensor for the multi head.
|
||||
|
||||
Raises:
|
||||
@ -1274,7 +1370,7 @@ class _MultiHead(_Head):
|
||||
self._logits_dimension += head.logits_dimension
|
||||
|
||||
self._heads = heads
|
||||
self._loss_combiner = loss_combiner
|
||||
self._loss_merger = loss_merger
|
||||
|
||||
@property
|
||||
def logits_dimension(self):
|
||||
@ -1353,11 +1449,11 @@ class _MultiHead(_Head):
|
||||
if mode == model_fn.ModeKeys.TRAIN:
|
||||
if train_op_fn is None:
|
||||
raise ValueError("train_op_fn can not be None in TRAIN mode.")
|
||||
return self._combine_train(all_model_fn_ops, train_op_fn)
|
||||
return self._merge_train(all_model_fn_ops, train_op_fn)
|
||||
if mode == model_fn.ModeKeys.INFER:
|
||||
return self._combine_infer(all_model_fn_ops)
|
||||
return self._merge_infer(all_model_fn_ops)
|
||||
if mode == model_fn.ModeKeys.EVAL:
|
||||
return self._combine_eval(all_model_fn_ops)
|
||||
return self._merge_eval(all_model_fn_ops)
|
||||
raise ValueError("mode=%s unrecognized" % str(mode))
|
||||
|
||||
def _split_logits(self, logits):
|
||||
@ -1379,8 +1475,8 @@ class _MultiHead(_Head):
|
||||
begin += current_logits_size
|
||||
return all_logits
|
||||
|
||||
def _combine_train(self, all_model_fn_ops, train_op_fn):
|
||||
"""Combines list of ModelFnOps for training.
|
||||
def _merge_train(self, all_model_fn_ops, train_op_fn):
|
||||
"""Merges list of ModelFnOps for training.
|
||||
|
||||
Args:
|
||||
all_model_fn_ops: list of ModelFnOps for the individual heads.
|
||||
@ -1388,14 +1484,14 @@ class _MultiHead(_Head):
|
||||
documentaion for more details.
|
||||
|
||||
Returns:
|
||||
ModelFnOps that combines all the heads.
|
||||
ModelFnOps that merges all heads for TRAIN.
|
||||
"""
|
||||
losses = []
|
||||
additional_train_ops = []
|
||||
for m in all_model_fn_ops:
|
||||
losses.append(m.loss)
|
||||
additional_train_ops.append(m.train_op)
|
||||
loss = self._loss_combiner(losses)
|
||||
loss = self._loss_merger(losses)
|
||||
|
||||
train_op = train_op_fn(loss)
|
||||
train_op = control_flow_ops.group(train_op, *additional_train_ops)
|
||||
@ -1404,14 +1500,14 @@ class _MultiHead(_Head):
|
||||
loss=loss,
|
||||
train_op=train_op)
|
||||
|
||||
def _combine_infer(self, all_model_fn_ops):
|
||||
"""Combines list of ModelFnOps for inference.
|
||||
def _merge_infer(self, all_model_fn_ops):
|
||||
"""Merges list of ModelFnOps for inference.
|
||||
|
||||
Args:
|
||||
all_model_fn_ops: list of ModelFnOps for the individual heads.
|
||||
|
||||
Returns:
|
||||
ModelFnOps that combines all the heads.
|
||||
ModelFnOps that Merges all the heads for INFER.
|
||||
"""
|
||||
predictions = {}
|
||||
output_alternatives = {}
|
||||
@ -1426,14 +1522,14 @@ class _MultiHead(_Head):
|
||||
predictions=predictions,
|
||||
output_alternatives=output_alternatives)
|
||||
|
||||
def _combine_eval(self, all_model_fn_ops):
|
||||
"""Combines list of ModelFnOps for eval.
|
||||
def _merge_eval(self, all_model_fn_ops):
|
||||
"""Merges list of ModelFnOps for eval.
|
||||
|
||||
Args:
|
||||
all_model_fn_ops: list of ModelFnOps for the individual heads.
|
||||
|
||||
Returns:
|
||||
ModelFnOps that combines all the heads.
|
||||
ModelFnOps that merges all the heads for EVAL.
|
||||
"""
|
||||
predictions = {}
|
||||
metrics = {}
|
||||
@ -1446,7 +1542,7 @@ class _MultiHead(_Head):
|
||||
for k, v in m.eval_metric_ops.items():
|
||||
# metrics["%s/%s" % (k, head_name)] = v
|
||||
metrics[k] = v
|
||||
loss = self._loss_combiner(losses)
|
||||
loss = self._loss_merger(losses)
|
||||
|
||||
return model_fn.ModelFnOps(
|
||||
mode=model_fn.ModeKeys.EVAL,
|
||||
@ -1733,3 +1829,14 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
|
||||
predictions, labels=labels, thresholds=(threshold,),
|
||||
weights=_float_weights_or_none(weights))
|
||||
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
|
||||
|
||||
|
||||
# Aliases
|
||||
# TODO(zakaria): Remove these aliases, See b/34751732
|
||||
_regression_head = regression_head
|
||||
_poisson_regression_head = poisson_regression_head
|
||||
_multi_class_head = multi_class_head
|
||||
_binary_svm_head = binary_svm_head
|
||||
_multi_label_head = multi_label_head
|
||||
_multi_head = multi_head
|
||||
_Head = Head
|
||||
|
@ -112,7 +112,7 @@ class PoissonHeadTest(test.TestCase):
|
||||
return sum(lpl)/len(lpl)
|
||||
|
||||
def testPoissonWithLogits(self):
|
||||
head = head_lib._poisson_regression_head()
|
||||
head = head_lib.poisson_regression_head()
|
||||
labels = ((0.,), (1.,), (1.,))
|
||||
logits = ((0.,), (-1.,), (3.,))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
@ -140,7 +140,7 @@ class RegressionHeadTest(test.TestCase):
|
||||
|
||||
# TODO(zakaria): test multilabel regression.
|
||||
def testRegressionWithLogits(self):
|
||||
head = head_lib._regression_head()
|
||||
head = head_lib.regression_head()
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{},
|
||||
@ -154,7 +154,7 @@ class RegressionHeadTest(test.TestCase):
|
||||
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
|
||||
|
||||
def testRegressionWithInvalidLogits(self):
|
||||
head = head_lib._regression_head()
|
||||
head = head_lib.regression_head()
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
||||
head.create_model_fn_ops(
|
||||
@ -165,7 +165,7 @@ class RegressionHeadTest(test.TestCase):
|
||||
logits=((1., 1.), (1., 1.), (3., 1.)))
|
||||
|
||||
def testRegressionWithLogitsInput(self):
|
||||
head = head_lib._regression_head()
|
||||
head = head_lib.regression_head()
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{},
|
||||
@ -183,7 +183,7 @@ class RegressionHeadTest(test.TestCase):
|
||||
_assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops)
|
||||
|
||||
def testRegressionWithLogitsAndLogitsInput(self):
|
||||
head = head_lib._regression_head()
|
||||
head = head_lib.regression_head()
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Both logits and logits_input supplied"):
|
||||
@ -196,7 +196,7 @@ class RegressionHeadTest(test.TestCase):
|
||||
logits=((1.,), (1.,), (3.,)))
|
||||
|
||||
def testRegressionEvalMode(self):
|
||||
head = head_lib._regression_head()
|
||||
head = head_lib.regression_head()
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{},
|
||||
@ -212,7 +212,7 @@ class RegressionHeadTest(test.TestCase):
|
||||
|
||||
def testRegressionWithLabelName(self):
|
||||
label_name = "my_label"
|
||||
head = head_lib._regression_head(label_name=label_name)
|
||||
head = head_lib.regression_head(label_name=label_name)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{},
|
||||
@ -226,7 +226,7 @@ class RegressionHeadTest(test.TestCase):
|
||||
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
|
||||
|
||||
def testRegressionWithWeights(self):
|
||||
head = head_lib._regression_head(weight_column_name="label_weight")
|
||||
head = head_lib.regression_head(weight_column_name="label_weight")
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
weights = ((2.,), (5.,), (0.,))
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
@ -242,7 +242,7 @@ class RegressionHeadTest(test.TestCase):
|
||||
model_fn_ops)
|
||||
|
||||
def testRegressionWithCenteredBias(self):
|
||||
head = head_lib._regression_head(enable_centered_bias=True)
|
||||
head = head_lib.regression_head(enable_centered_bias=True)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{},
|
||||
@ -264,7 +264,7 @@ class RegressionHeadTest(test.TestCase):
|
||||
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
|
||||
|
||||
def testRegressionErrorInSparseTensorLabels(self):
|
||||
head = head_lib._regression_head()
|
||||
head = head_lib.regression_head()
|
||||
with ops.Graph().as_default():
|
||||
labels = sparse_tensor.SparseTensorValue(
|
||||
indices=((0, 0), (1, 0), (2, 0)),
|
||||
@ -317,7 +317,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
|
||||
def testMultiLabelWithLogits(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
@ -334,7 +334,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
n_classes = 2
|
||||
labels = ((0, 1),)
|
||||
logits = ((1., 0.),)
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
@ -361,7 +361,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
}, model_fn_ops)
|
||||
|
||||
def testMultiLabelWithInvalidLogits(self):
|
||||
head = head_lib._multi_label_head(n_classes=len(self._labels[0]) + 1)
|
||||
head = head_lib.multi_label_head(n_classes=len(self._labels[0]) + 1)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
||||
head.create_model_fn_ops(
|
||||
@ -370,7 +370,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
|
||||
def testMultiLabelWithLogitsInput(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
@ -407,7 +407,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
|
||||
def testMultiLabelWithLogitsAndLogitsInput(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(
|
||||
@ -418,7 +418,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
|
||||
def testMultiLabelEvalMode(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
@ -434,7 +434,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
|
||||
def testMultiClassEvalModeWithLargeLogits(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
logits = ((2., 0., -1),)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
@ -474,7 +474,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
def testMultiLabelWithLabelName(self):
|
||||
n_classes = 3
|
||||
label_name = "my_label"
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes,
|
||||
label_name=label_name,
|
||||
metric_class_ids=range(n_classes))
|
||||
@ -491,7 +491,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
|
||||
def testMultiLabelWithWeight(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes,
|
||||
weight_column_name="label_weight",
|
||||
metric_class_ids=range(n_classes))
|
||||
@ -510,7 +510,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
|
||||
def testMultiLabelWithCustomLoss(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes,
|
||||
weight_column_name="label_weight",
|
||||
metric_class_ids=range(n_classes),
|
||||
@ -530,7 +530,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
|
||||
def testMultiLabelWithCenteredBias(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes,
|
||||
enable_centered_bias=True,
|
||||
metric_class_ids=range(n_classes))
|
||||
@ -559,7 +559,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
|
||||
def testMultiLabelSparseTensorLabels(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
labels = sparse_tensor.SparseTensorValue(
|
||||
@ -580,7 +580,7 @@ class MultiLabelHeadTest(test.TestCase):
|
||||
|
||||
def testMultiLabelSparseTensorLabelsTooFewClasses(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_label_head(
|
||||
head = head_lib.multi_label_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
# Set _logits_dimension (n_classes) to a lower value; if it's set to 1
|
||||
# upfront, the class throws an error during initialization.
|
||||
@ -629,7 +629,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
|
||||
def testBinaryClassificationWithLogits(self):
|
||||
n_classes = 2
|
||||
head = head_lib._multi_class_head(n_classes=n_classes)
|
||||
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
@ -644,7 +644,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
||||
|
||||
def testBinaryClassificationWithInvalidLogits(self):
|
||||
head = head_lib._multi_class_head(n_classes=len(self._labels) + 1)
|
||||
head = head_lib.multi_class_head(n_classes=len(self._labels) + 1)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
||||
head.create_model_fn_ops(
|
||||
@ -653,7 +653,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
|
||||
def testBinaryClassificationWithLogitsInput(self):
|
||||
n_classes = 2
|
||||
head = head_lib._multi_class_head(n_classes=n_classes)
|
||||
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
@ -682,7 +682,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
}, model_fn_ops)
|
||||
|
||||
def testBinaryClassificationWithLogitsAndLogitsInput(self):
|
||||
head = head_lib._multi_class_head(n_classes=2)
|
||||
head = head_lib.multi_class_head(n_classes=2)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Both logits and logits_input supplied"):
|
||||
@ -692,7 +692,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
|
||||
def testBinaryClassificationEvalMode(self):
|
||||
n_classes = 2
|
||||
head = head_lib._multi_class_head(n_classes=n_classes)
|
||||
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
@ -709,7 +709,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
|
||||
def testBinaryClassificationInferMode(self):
|
||||
n_classes = 2
|
||||
head = head_lib._multi_class_head(n_classes=n_classes)
|
||||
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
@ -722,7 +722,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
|
||||
def testBinaryClassificationInferMode_withWightColumn(self):
|
||||
n_classes = 2
|
||||
head = head_lib._multi_class_head(n_classes=n_classes,
|
||||
head = head_lib.multi_class_head(n_classes=n_classes,
|
||||
weight_column_name="label_weight")
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
@ -738,7 +738,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
|
||||
def testErrorInSparseTensorLabels(self):
|
||||
n_classes = 2
|
||||
head = head_lib._multi_class_head(n_classes=n_classes)
|
||||
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||
with ops.Graph().as_default():
|
||||
labels = sparse_tensor.SparseTensorValue(
|
||||
indices=((0, 0), (1, 0), (2, 0)),
|
||||
@ -755,7 +755,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
|
||||
def testBinaryClassificationWithLabelName(self):
|
||||
label_name = "my_label"
|
||||
head = head_lib._multi_class_head(n_classes=2, label_name=label_name)
|
||||
head = head_lib.multi_class_head(n_classes=2, label_name=label_name)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
@ -774,7 +774,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
|
||||
def testBinaryClassificationWithWeights(self):
|
||||
n_classes = 2
|
||||
head = head_lib._multi_class_head(
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes, weight_column_name="label_weight")
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
weights = ((1.,), (0.,))
|
||||
@ -808,7 +808,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
model_fn_ops)
|
||||
|
||||
def testBinaryClassificationWithCustomLoss(self):
|
||||
head = head_lib._multi_class_head(
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=2, weight_column_name="label_weight",
|
||||
loss_fn=_sigmoid_cross_entropy)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
@ -844,7 +844,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
||||
model_fn_ops)
|
||||
|
||||
def testBinaryClassificationWithCenteredBias(self):
|
||||
head = head_lib._multi_class_head(n_classes=2, enable_centered_bias=True)
|
||||
head = head_lib.multi_class_head(n_classes=2, enable_centered_bias=True)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||
@ -904,7 +904,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
|
||||
def testMultiClassWithLogits(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_class_head(
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
@ -920,7 +920,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
||||
|
||||
def testMultiClassWithInvalidLogits(self):
|
||||
head = head_lib._multi_class_head(n_classes=len(self._logits[0]) + 1)
|
||||
head = head_lib.multi_class_head(n_classes=len(self._logits[0]) + 1)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
||||
head.create_model_fn_ops(
|
||||
@ -928,7 +928,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
logits=self._logits)
|
||||
|
||||
def testMultiClassWithNoneTrainOpFnInTrain(self):
|
||||
head = head_lib._multi_class_head(n_classes=3)
|
||||
head = head_lib.multi_class_head(n_classes=3)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "train_op_fn can not be None in TRAIN mode"):
|
||||
@ -939,7 +939,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
|
||||
def testMultiClassWithLogitsInput(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_class_head(
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
@ -978,7 +978,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
|
||||
def testMultiClassWithLogitsAndLogitsInput(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_class_head(
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(
|
||||
@ -989,7 +989,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
|
||||
def testMultiClassEvalMode(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_class_head(
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
# logloss: z:label, x:logit
|
||||
@ -1007,7 +1007,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
|
||||
def testMultiClassEvalModeWithLargeLogits(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_class_head(
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||
logits = ((2., 0., -1),)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
@ -1046,7 +1046,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
|
||||
def testMultiClassWithWeight(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_class_head(
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes,
|
||||
weight_column_name="label_weight",
|
||||
metric_class_ids=range(n_classes))
|
||||
@ -1069,7 +1069,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
|
||||
def testMultiClassWithCustomLoss(self):
|
||||
n_classes = 3
|
||||
head = head_lib._multi_class_head(
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes,
|
||||
weight_column_name="label_weight",
|
||||
metric_class_ids=range(n_classes),
|
||||
@ -1094,7 +1094,7 @@ class MultiClassHeadTest(test.TestCase):
|
||||
def testInvalidNClasses(self):
|
||||
for n_classes in (None, -1, 0, 1):
|
||||
with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):
|
||||
head_lib._multi_class_head(n_classes=n_classes)
|
||||
head_lib.multi_class_head(n_classes=n_classes)
|
||||
|
||||
|
||||
class BinarySvmHeadTest(test.TestCase):
|
||||
@ -1116,7 +1116,7 @@ class BinarySvmHeadTest(test.TestCase):
|
||||
self._expected_losses = (.5, 0.)
|
||||
|
||||
def testBinarySVMWithLogits(self):
|
||||
head = head_lib._binary_svm_head()
|
||||
head = head_lib.binary_svm_head()
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{},
|
||||
@ -1134,7 +1134,7 @@ class BinarySvmHeadTest(test.TestCase):
|
||||
}, model_fn_ops)
|
||||
|
||||
def testBinarySVMWithInvalidLogits(self):
|
||||
head = head_lib._binary_svm_head()
|
||||
head = head_lib.binary_svm_head()
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
||||
head.create_model_fn_ops(
|
||||
@ -1142,7 +1142,7 @@ class BinarySvmHeadTest(test.TestCase):
|
||||
logits=np.ones((2, 2)))
|
||||
|
||||
def testBinarySVMWithLogitsInput(self):
|
||||
head = head_lib._binary_svm_head()
|
||||
head = head_lib.binary_svm_head()
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{},
|
||||
@ -1164,7 +1164,7 @@ class BinarySvmHeadTest(test.TestCase):
|
||||
}, model_fn_ops)
|
||||
|
||||
def testBinarySVMWithLogitsAndLogitsInput(self):
|
||||
head = head_lib._binary_svm_head()
|
||||
head = head_lib.binary_svm_head()
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Both logits and logits_input supplied"):
|
||||
@ -1177,7 +1177,7 @@ class BinarySvmHeadTest(test.TestCase):
|
||||
logits=self._predictions)
|
||||
|
||||
def testBinarySVMEvalMode(self):
|
||||
head = head_lib._binary_svm_head()
|
||||
head = head_lib.binary_svm_head()
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{},
|
||||
@ -1197,7 +1197,7 @@ class BinarySvmHeadTest(test.TestCase):
|
||||
|
||||
def testBinarySVMWithLabelName(self):
|
||||
label_name = "my_label"
|
||||
head = head_lib._binary_svm_head(label_name=label_name)
|
||||
head = head_lib.binary_svm_head(label_name=label_name)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{},
|
||||
@ -1215,7 +1215,7 @@ class BinarySvmHeadTest(test.TestCase):
|
||||
}, model_fn_ops)
|
||||
|
||||
def testBinarySVMWithWeights(self):
|
||||
head = head_lib._binary_svm_head(weight_column_name="weights")
|
||||
head = head_lib.binary_svm_head(weight_column_name="weights")
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
weights = (7., 11.)
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
@ -1235,7 +1235,7 @@ class BinarySvmHeadTest(test.TestCase):
|
||||
}, model_fn_ops)
|
||||
|
||||
def testBinarySVMWithCenteredBias(self):
|
||||
head = head_lib._binary_svm_head(enable_centered_bias=True)
|
||||
head = head_lib.binary_svm_head(enable_centered_bias=True)
|
||||
with ops.Graph().as_default(), session.Session():
|
||||
model_fn_ops = head.create_model_fn_ops(
|
||||
{},
|
||||
@ -1265,21 +1265,21 @@ class BinarySvmHeadTest(test.TestCase):
|
||||
class MultiHeadTest(test.TestCase):
|
||||
|
||||
def testInvalidHeads(self):
|
||||
named_head = head_lib._multi_class_head(
|
||||
named_head = head_lib.multi_class_head(
|
||||
n_classes=3, label_name="label", head_name="head1")
|
||||
unnamed_head = head_lib._multi_class_head(
|
||||
unnamed_head = head_lib.multi_class_head(
|
||||
n_classes=4, label_name="label")
|
||||
with self.assertRaisesRegexp(ValueError, "must have names"):
|
||||
head_lib._multi_head((named_head, unnamed_head))
|
||||
head_lib.multi_head((named_head, unnamed_head))
|
||||
with self.assertRaisesRegexp(ValueError, "must be SingleHead"):
|
||||
head_lib._multi_head((named_head, head_lib._multi_head((named_head,))))
|
||||
head_lib.multi_head((named_head, head_lib.multi_head((named_head,))))
|
||||
|
||||
def testTrainWithNoneTrainOpFn(self):
|
||||
head1 = head_lib._multi_class_head(
|
||||
head1 = head_lib.multi_class_head(
|
||||
n_classes=3, label_name="label1", head_name="head1")
|
||||
head2 = head_lib._multi_class_head(
|
||||
head2 = head_lib.multi_class_head(
|
||||
n_classes=4, label_name="label2", head_name="head2")
|
||||
head = head_lib._multi_head((head1, head2))
|
||||
head = head_lib.multi_head((head1, head2))
|
||||
labels = {
|
||||
"label1": (1,),
|
||||
"label2": (1,)
|
||||
@ -1294,11 +1294,11 @@ class MultiHeadTest(test.TestCase):
|
||||
logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))
|
||||
|
||||
def testTrain_withNoHeadWeights(self):
|
||||
head1 = head_lib._multi_class_head(
|
||||
head1 = head_lib.multi_class_head(
|
||||
n_classes=3, label_name="label1", head_name="head1")
|
||||
head2 = head_lib._multi_class_head(
|
||||
head2 = head_lib.multi_class_head(
|
||||
n_classes=4, label_name="label2", head_name="head2")
|
||||
head = head_lib._multi_head((head1, head2))
|
||||
head = head_lib.multi_head((head1, head2))
|
||||
labels = {
|
||||
"label1": (1,),
|
||||
"label2": (1,)
|
||||
@ -1320,11 +1320,11 @@ class MultiHeadTest(test.TestCase):
|
||||
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
|
||||
|
||||
def testTrain_withHeadWeights(self):
|
||||
head1 = head_lib._multi_class_head(
|
||||
head1 = head_lib.multi_class_head(
|
||||
n_classes=3, label_name="label1", head_name="head1")
|
||||
head2 = head_lib._multi_class_head(
|
||||
head2 = head_lib.multi_class_head(
|
||||
n_classes=4, label_name="label2", head_name="head2")
|
||||
head = head_lib._multi_head((head1, head2), (1, .5))
|
||||
head = head_lib.multi_head((head1, head2), (1, .5))
|
||||
labels = {
|
||||
"label1": (1,),
|
||||
"label2": (1,)
|
||||
@ -1345,11 +1345,11 @@ class MultiHeadTest(test.TestCase):
|
||||
self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3)
|
||||
|
||||
def testTrain_withDictLogits(self):
|
||||
head1 = head_lib._multi_class_head(
|
||||
head1 = head_lib.multi_class_head(
|
||||
n_classes=3, label_name="label1", head_name="head1")
|
||||
head2 = head_lib._multi_class_head(
|
||||
head2 = head_lib.multi_class_head(
|
||||
n_classes=4, label_name="label2", head_name="head2")
|
||||
head = head_lib._multi_head((head1, head2))
|
||||
head = head_lib.multi_head((head1, head2))
|
||||
labels = {
|
||||
"label1": (1,),
|
||||
"label2": (1,)
|
||||
@ -1372,11 +1372,11 @@ class MultiHeadTest(test.TestCase):
|
||||
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
|
||||
|
||||
def testInfer(self):
|
||||
head1 = head_lib._multi_class_head(
|
||||
head1 = head_lib.multi_class_head(
|
||||
n_classes=3, label_name="label1", head_name="head1")
|
||||
head2 = head_lib._multi_class_head(
|
||||
head2 = head_lib.multi_class_head(
|
||||
n_classes=4, label_name="label2", head_name="head2")
|
||||
head = head_lib._multi_head((head1, head2), (1, .5))
|
||||
head = head_lib.multi_head((head1, head2), (1, .5))
|
||||
labels = {
|
||||
"label1": (1,),
|
||||
"label2": (1,)
|
||||
@ -1422,11 +1422,11 @@ class MultiHeadTest(test.TestCase):
|
||||
), model_fn_ops.output_alternatives["head2"][1].keys())
|
||||
|
||||
def testEval(self):
|
||||
head1 = head_lib._multi_class_head(
|
||||
head1 = head_lib.multi_class_head(
|
||||
n_classes=3, label_name="label1", head_name="head1")
|
||||
head2 = head_lib._multi_class_head(
|
||||
head2 = head_lib.multi_class_head(
|
||||
n_classes=4, label_name="label2", head_name="head2")
|
||||
head = head_lib._multi_head((head1, head2), (1, .5))
|
||||
head = head_lib.multi_head((head1, head2), (1, .5))
|
||||
labels = {
|
||||
"label1": (1,),
|
||||
"label2": (1,)
|
||||
|
@ -419,7 +419,7 @@ class LinearClassifier(estimator.Estimator):
|
||||
enable_centered_bias = False
|
||||
logging.warning("centered_bias is not supported with SDCA, "
|
||||
"please disable it explicitly.")
|
||||
head = head_lib._multi_class_head( # pylint: disable=protected-access
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=enable_centered_bias)
|
||||
@ -686,7 +686,7 @@ class LinearRegressor(estimator.Estimator):
|
||||
enable_centered_bias = False
|
||||
logging.warning("centered_bias is not supported with SDCA, "
|
||||
"please disable it explicitly.")
|
||||
head = head_lib._regression_head( # pylint: disable=protected-access
|
||||
head = head_lib.regression_head(
|
||||
weight_column_name=weight_column_name,
|
||||
label_dimension=label_dimension,
|
||||
enable_centered_bias=enable_centered_bias)
|
||||
@ -824,8 +824,7 @@ class LinearRegressor(estimator.Estimator):
|
||||
exports_to_keep=exports_to_keep)
|
||||
|
||||
|
||||
# TODO(zakaria): Make it public when b/34751732 is fixed.
|
||||
class _LinearEstimator(estimator.Estimator):
|
||||
class LinearEstimator(estimator.Estimator):
|
||||
"""Linear model with user specified head.
|
||||
|
||||
Train a generalized linear model to predict label value given observation of
|
||||
@ -840,9 +839,9 @@ class _LinearEstimator(estimator.Estimator):
|
||||
|
||||
sparse_feature_a_x_sparse_feature_b = crossed_column(...)
|
||||
|
||||
estimator = _LinearEstimator(
|
||||
estimator = LinearEstimator(
|
||||
feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
|
||||
head=head_lib._poisson_regression_head())
|
||||
head=head_lib.poisson_regression_head())
|
||||
|
||||
# Input builders
|
||||
def input_fn_train: # returns x, y
|
||||
@ -879,7 +878,7 @@ class _LinearEstimator(estimator.Estimator):
|
||||
_joint_weights=False,
|
||||
config=None,
|
||||
feature_engineering_fn=None):
|
||||
"""Construct a `_LinearEstimator` object.
|
||||
"""Construct a `LinearEstimator` object.
|
||||
|
||||
Args:
|
||||
feature_columns: An iterable containing all the feature columns used by
|
||||
@ -907,14 +906,14 @@ class _LinearEstimator(estimator.Estimator):
|
||||
into the model.
|
||||
|
||||
Returns:
|
||||
A `_LinearEstimator` estimator.
|
||||
A `LinearEstimator` estimator.
|
||||
|
||||
Raises:
|
||||
ValueError: if optimizer is not supported, e.g., SDCAOptimizer
|
||||
"""
|
||||
assert feature_columns
|
||||
if isinstance(optimizer, sdca_optimizer.SDCAOptimizer):
|
||||
raise ValueError("_LinearEstimator does not support SDCA optimizer.")
|
||||
raise ValueError("LinearEstimator does not support SDCA optimizer.")
|
||||
|
||||
params = {
|
||||
"head": head,
|
||||
@ -923,7 +922,7 @@ class _LinearEstimator(estimator.Estimator):
|
||||
"gradient_clip_norm": gradient_clip_norm,
|
||||
"joint_weights": _joint_weights,
|
||||
}
|
||||
super(_LinearEstimator, self).__init__(
|
||||
super(LinearEstimator, self).__init__(
|
||||
model_fn=_linear_model_fn,
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
|
@ -1665,15 +1665,15 @@ class LinearEstimatorTest(test.TestCase):
|
||||
'feature', dimension=4)
|
||||
]
|
||||
exp = experiment.Experiment(
|
||||
estimator=linear._LinearEstimator(feature_columns=cont_features,
|
||||
head=head_lib._regression_head()),
|
||||
estimator=linear.LinearEstimator(feature_columns=cont_features,
|
||||
head=head_lib.regression_head()),
|
||||
train_input_fn=test_data.iris_input_logistic_fn,
|
||||
eval_input_fn=test_data.iris_input_logistic_fn)
|
||||
exp.test()
|
||||
|
||||
def testEstimatorContract(self):
|
||||
estimator_test_utils.assert_estimator_contract(self,
|
||||
linear._LinearEstimator)
|
||||
linear.LinearEstimator)
|
||||
|
||||
def testLinearRegression(self):
|
||||
"""Tests that loss goes down with training."""
|
||||
@ -1691,8 +1691,8 @@ class LinearEstimatorTest(test.TestCase):
|
||||
100)
|
||||
age = feature_column_lib.real_valued_column('age')
|
||||
|
||||
linear_estimator = linear._LinearEstimator(feature_columns=[age, language],
|
||||
head=head_lib._regression_head())
|
||||
linear_estimator = linear.LinearEstimator(feature_columns=[age, language],
|
||||
head=head_lib.regression_head())
|
||||
linear_estimator.fit(input_fn=input_fn, steps=100)
|
||||
loss1 = linear_estimator.evaluate(input_fn=input_fn, steps=1)['loss']
|
||||
linear_estimator.fit(input_fn=input_fn, steps=400)
|
||||
@ -1717,9 +1717,9 @@ class LinearEstimatorTest(test.TestCase):
|
||||
100)
|
||||
age = feature_column_lib.real_valued_column('age')
|
||||
|
||||
linear_estimator = linear._LinearEstimator(
|
||||
linear_estimator = linear.LinearEstimator(
|
||||
feature_columns=[age, language],
|
||||
head=head_lib._poisson_regression_head())
|
||||
head=head_lib.poisson_regression_head())
|
||||
linear_estimator.fit(input_fn=input_fn, steps=10)
|
||||
loss1 = linear_estimator.evaluate(input_fn=input_fn, steps=1)['loss']
|
||||
linear_estimator.fit(input_fn=input_fn, steps=100)
|
||||
@ -1736,8 +1736,8 @@ class LinearEstimatorTest(test.TestCase):
|
||||
sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
|
||||
example_id_column='example_id')
|
||||
with self.assertRaises(ValueError):
|
||||
linear._LinearEstimator(
|
||||
head=head_lib._regression_head(label_dimension=1),
|
||||
linear.LinearEstimator(
|
||||
head=head_lib.regression_head(label_dimension=1),
|
||||
feature_columns=[maintenance_cost, sq_footage],
|
||||
optimizer=sdca_optimizer,
|
||||
_joint_weights=True)
|
||||
|
@ -139,7 +139,7 @@ class SVM(estimator.Estimator):
|
||||
model_dir=model_dir,
|
||||
config=config,
|
||||
params={
|
||||
"head": head_lib._binary_svm_head( # pylint: disable=protected-access
|
||||
"head": head_lib.binary_svm_head(
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=False),
|
||||
"feature_columns": feature_columns,
|
||||
|
Loading…
Reference in New Issue
Block a user