Makes the head/multi-head API public and updates selected callers, leaving aliases in place for the remaining ones.

This is CL#1 in a series of CLs that make the head/multi-head API public and migrate all users. This CL:
+ Makes the Head interface and factory functions public.
+ Updates all tf-learn internal and SIR usage.
+ Leaves aliases for the legacy private names; these will be removed, along with all remaining usages, in follow-up CLs.
+ Also updates documentation.
Change: 149613397
Author: Zakaria Haque (2017-03-08 19:51:12 -08:00); committed by TensorFlower Gardener
commit 58067591b6 (parent 5b3e560d2f)
12 changed files with 390 additions and 265 deletions
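At call sites the rename is mechanical; a minimal before/after sketch (the `head_lib` import alias follows the convention used in the files below):

```python
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib

# Before this CL: private factory; call sites needed a pylint suppression.
# head = head_lib._multi_class_head(n_classes=3)  # pylint: disable=protected-access

# After this CL: public factory; the old name survives as a temporary alias.
head = head_lib.multi_class_head(n_classes=3)
```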

__init__.py

@@ -28,13 +28,24 @@ See the @{$python/contrib.learn} guide.
 @@MetricSpec
 @@PredictionKey
 @@DNNClassifier
+@@DNNEstimator
 @@DNNRegressor
 @@DNNLinearCombinedRegressor
 @@DNNLinearCombinedClassifier
 @@LinearClassifier
+@@LinearEstimator
 @@LinearRegressor
 @@LogisticRegressor
+@@Head
+@@multi_class_head
+@@multi_label_head
+@@binary_svm_head
+@@regression_head
+@@poisson_regression_head
+@@multi_head
+@@no_op_train_fn
 @@Experiment
 @@ExportStrategy
 @@TaskType

__init__.py

@@ -295,6 +295,7 @@ from __future__ import print_function
 from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
 from tensorflow.contrib.learn.python.learn.estimators.constants import ProblemType
 from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
+from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNEstimator
 from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
 from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedClassifier
 from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedRegressor
@@ -304,8 +305,17 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
 from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
 from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
 from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
+from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head
+from tensorflow.contrib.learn.python.learn.estimators.head import Head
+from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head
+from tensorflow.contrib.learn.python.learn.estimators.head import multi_head
+from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head
+from tensorflow.contrib.learn.python.learn.estimators.head import no_op_train_fn
+from tensorflow.contrib.learn.python.learn.estimators.head import poisson_regression_head
+from tensorflow.contrib.learn.python.learn.estimators.head import regression_head
 from tensorflow.contrib.learn.python.learn.estimators.kmeans import KMeansClustering
 from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
+from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator
 from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
 from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
 from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
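With these imports in place, the head factories should be reachable from the package root; a quick sketch (assuming a TF 1.x build that includes this CL):

```python
import tensorflow as tf

# Public factories re-exported above, e.g.:
binary = tf.contrib.learn.multi_class_head(n_classes=2)
scores = tf.contrib.learn.regression_head(label_dimension=1)
```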

composable_model_test.py

@@ -131,7 +131,7 @@ class ComposableModelTest(test.TestCase):
     language = feature_column.sparse_column_with_hash_bucket('language', 100)
     age = feature_column.real_valued_column('age')
-    head = head_lib._multi_class_head(n_classes=2)
+    head = head_lib.multi_class_head(n_classes=2)
     classifier = _linear_estimator(head, feature_columns=[age, language])
     classifier.fit(input_fn=input_fn, steps=1000)
@@ -157,7 +157,7 @@ class ComposableModelTest(test.TestCase):
     language = feature_column.sparse_column_with_hash_bucket('language', 100)
     age = feature_column.sparse_column_with_hash_bucket('age', 2)
-    head = head_lib._multi_class_head(n_classes=2)
+    head = head_lib.multi_class_head(n_classes=2)
     classifier = _joint_linear_estimator(head, feature_columns=[age, language])
     classifier.fit(input_fn=input_fn, steps=1000)
@@ -171,7 +171,7 @@ class ComposableModelTest(test.TestCase):
     """Tests multi-class classification using matrix data as input."""
     cont_features = [feature_column.real_valued_column('feature', dimension=4)]
-    head = head_lib._multi_class_head(n_classes=3)
+    head = head_lib.multi_class_head(n_classes=3)
     classifier = _dnn_estimator(
         head, feature_columns=cont_features, hidden_units=[3, 3])

dnn.py

@@ -304,7 +304,7 @@ class DNNClassifier(estimator.Estimator):
         config=config,
         params={
             "head":
-                head_lib._multi_class_head(  # pylint: disable=protected-access
+                head_lib.multi_class_head(
                     n_classes,
                     weight_column_name=weight_column_name,
                     enable_centered_bias=enable_centered_bias),
@@ -579,7 +579,7 @@ class DNNRegressor(estimator.Estimator):
         config=config,
         params={
             "head":
-                head_lib._regression_head(  # pylint: disable=protected-access
+                head_lib.regression_head(
                     label_dimension=label_dimension,
                     weight_column_name=weight_column_name,
                     enable_centered_bias=enable_centered_bias),
@@ -731,8 +731,7 @@ class DNNRegressor(estimator.Estimator):
         exports_to_keep=exports_to_keep)

-# TODO(zakaria): Make it public when b/34751732 is fixed.
-class _DNNEstimator(estimator.Estimator):
+class DNNEstimator(estimator.Estimator):
   """An Estimator for TensorFlow DNN models with user specified _Head.

   Example:
@@ -745,20 +744,20 @@ class _DNNEstimator(estimator.Estimator):
                                          ...)
   sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b,
                                          ...)
-  To create a _DNNEstimator for binary classification, where
+  To create a DNNEstimator for binary classification, where

-  estimator = _DNNEstimator(
+  estimator = DNNEstimator(
       feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
-      head=head_lib._multi_class__head(n_classes=2),
+      head=tf.contrib.learn.multi_class_head(n_classes=2),
       hidden_units=[1024, 512, 256])

   If your label is keyed with "y" in your labels dict, and weights are keyed
   with "w" in features dict, and you want to enable centered bias,

-  head = head_lib._multi_class__head(
+  head = tf.contrib.learn.multi_class_head(
       n_classes=2,
       label_name="x",
       weight_column_name="w",
       enable_centered_bias=True)
-  estimator = _DNNEstimator(
+  estimator = DNNEstimator(
       feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
       head=head,
       hidden_units=[1024, 512, 256])
@@ -802,10 +801,10 @@ class _DNNEstimator(estimator.Estimator):
                feature_engineering_fn=None,
                embedding_lr_multipliers=None,
                input_layer_min_slice_size=None):
-    """Initializes a _DNNEstimator instance.
+    """Initializes a `DNNEstimator` instance.

     Args:
-      head: _Head instance.
+      head: `Head` instance.
       hidden_units: List of hidden units per layer. All layers are fully
         connected. Ex. `[64, 32]` means first layer has 64 nodes and second one
         has 32.
@@ -836,9 +835,9 @@ class _DNNEstimator(estimator.Estimator):
         partitions. If not provided, will use the default of 64M.

     Returns:
-      A `_DNNEstimator` estimator.
+      A `DNNEstimator` estimator.
     """
-    super(_DNNEstimator, self).__init__(
+    super(DNNEstimator, self).__init__(
         model_fn=_dnn_model_fn,
         model_dir=model_dir,
         config=config,
@@ -854,4 +853,3 @@ class _DNNEstimator(estimator.Estimator):
             "input_layer_min_slice_size": input_layer_min_slice_size,
         },
         feature_engineering_fn=feature_engineering_fn)
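A minimal sketch of the now-public `DNNEstimator` paired with a public head factory (the feature column name and layer sizes here are illustrative, not taken from the CL):

```python
from tensorflow.contrib import layers
from tensorflow.contrib import learn

# Illustrative real-valued feature column named "x".
x_col = layers.real_valued_column("x")

estimator = learn.DNNEstimator(
    head=learn.multi_class_head(n_classes=2),  # public factory, no underscore
    feature_columns=[x_col],
    hidden_units=[16, 8])
```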

dnn_linear_combined.py

@@ -550,7 +550,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
     if not self._feature_columns:
       raise ValueError("Either linear_feature_columns or dnn_feature_columns "
                        "must be defined.")
-    head = head_lib._multi_class_head(  # pylint: disable=protected-access
+    head = head_lib.multi_class_head(
         n_classes=n_classes,
         weight_column_name=weight_column_name,
         enable_centered_bias=enable_centered_bias)
@@ -841,7 +841,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
     if not self._feature_columns:
       raise ValueError("Either linear_feature_columns or dnn_feature_columns "
                        "must be defined.")
-    head = head_lib._regression_head(  # pylint: disable=protected-access
+    head = head_lib.regression_head(
         weight_column_name=weight_column_name,
         label_dimension=label_dimension,
         enable_centered_bias=enable_centered_bias)

dnn_linear_combined_test.py

@@ -60,7 +60,7 @@ def _assert_metrics_in_range(keys, metrics):
       metrics)

-class _CheckCallsHead(head_lib._Head):  # pylint: disable=protected-access
+class _CheckCallsHead(head_lib.Head):
   """Head that checks whether head_ops is called."""

   def __init__(self):
@@ -97,7 +97,7 @@ class EmbeddingMultiplierTest(test.TestCase):
     params = {
         'dnn_feature_columns': [one_hot_language],
-        'head': head_lib._multi_class_head(2),
+        'head': head_lib.multi_class_head(2),
         'dnn_hidden_units': [1],
         # Set lr mult to 0. to keep embeddings constant.
         'embedding_lr_multipliers': {
@@ -131,7 +131,7 @@ class EmbeddingMultiplierTest(test.TestCase):
     params = {
         'dnn_feature_columns': [embedding_language, embedding_wire],
-        'head': head_lib._multi_class_head(2),
+        'head': head_lib.multi_class_head(2),
         'dnn_hidden_units': [1],
         # Set lr mult to 0. to keep embeddings constant.
         'embedding_lr_multipliers': {

dnn_test.py

@@ -59,7 +59,7 @@ class EmbeddingMultiplierTest(test.TestCase):
     params = {
         'feature_columns': [one_hot_language],
-        'head': head_lib._multi_class_head(2),
+        'head': head_lib.multi_class_head(2),
         'hidden_units': [1],
         # Set lr mult to 0. to keep embeddings constant.
         'embedding_lr_multipliers': {
@@ -90,7 +90,7 @@ class EmbeddingMultiplierTest(test.TestCase):
     params = {
         'feature_columns': [embedding_language, embedding_wire],
-        'head': head_lib._multi_class_head(2),
+        'head': head_lib.multi_class_head(2),
         'hidden_units': [1],
         # Set lr mult to 0. to keep embeddings constant.
         'embedding_lr_multipliers': {
@@ -145,7 +145,7 @@ class DNNEstimatorTest(test.TestCase):
     exp.test()

   def testEstimatorContract(self):
-    estimator_test_utils.assert_estimator_contract(self, dnn._DNNEstimator)
+    estimator_test_utils.assert_estimator_contract(self, dnn.DNNEstimator)

   def testTrainWithWeights(self):
     """Tests training with given weight column."""
@@ -172,8 +172,8 @@ class DNNEstimatorTest(test.TestCase):
       }
       return features, labels

-    dnn_estimator = dnn._DNNEstimator(
-        head=head_lib._multi_class_head(2, weight_column_name='w'),
+    dnn_estimator = dnn.DNNEstimator(
+        head=head_lib.multi_class_head(2, weight_column_name='w'),
         feature_columns=[feature_column.real_valued_column('x')],
         hidden_units=[3, 3],
         config=run_config.RunConfig(tf_random_seed=1))

head.py

@@ -46,15 +46,141 @@ from tensorflow.python.ops import variables
 from tensorflow.python.summary import summary
 from tensorflow.python.training import training

-# TODO(zakaria): add functions that creates a head and returns ModelOpFn
+class Head(object):
+  """Interface for the head/top of a model.
+
+  Given logits (or output of a hidden layer), a Head knows how to compute
+  predictions, loss, default metric and export signature. It is meant to,
+
+  1) Simplify writing model_fn and to make model_fn more configurable
+  2) Support wide range of machine learning models. Since most heads can work
+     with logits, they can support DNN, RNN, Wide, Wide&Deep,
+     Global objectives, Gradient boosted trees and many other types
+     of machine learning models.
+  3) To allow users to seamlessly switch between 1 to n heads for multi
+     objective learning (See _MultiHead implementation for more details)
+
+  Common usage:
+  Here is simplified model_fn to build a multiclass DNN model.
+  ```python
+  def _my_dnn_model_fn(features, labels, mode, params, config=None):
+    # Optionally your callers can pass head to model_fn as a param.
+    head = tf.contrib.learn.multi_class_head(...)
+    input = tf.contrib.layers.input_from_feature_columns(features, ...)
+    last_hidden_layer_out = tf.contrib.layers.stack(
+        input, tf.contrib.layers.fully_connected, [1000, 500])
+    logits = tf.contrib.layers.fully_connected(
+        last_hidden_layer_out, head.logits_dimension, activation_fn=None)
+
+    def _train_op_fn(loss):
+      return optimizer.minimize(loss)
+
+    return head.create_model_fn_ops(
+        features=features,
+        labels=labels,
+        mode=mode,
+        train_op_fn=_train_op_fn,
+        logits=logits,
+        scope=...)
+  ```
+
+  Most heads also support logits_input, which is typically the output of the
+  last hidden layer. Some heads (like heads responsible for candidate sampling
+  or hierarchical softmax) intrinsically will not support logits and you have
+  to pass logits_input. Here is a common usage,
+  ```python
+  return head.create_model_fn_ops(
+      features=features,
+      labels=labels,
+      mode=mode,
+      train_op_fn=_train_op_fn,
+      logits_input=last_hidden_layer_out,
+      scope=...)
+  ```
+
+  There are cases where computing and applying gradients cannot be
+  meaningfully captured with the train_op_fn we support (for example, with a
+  sync optimizer). In such case, you can take the responsibility on your own.
+  Here is a common use case,
+  ```python
+  model_fn_ops = head.create_model_fn_ops(
+      features=features,
+      labels=labels,
+      mode=mode,
+      train_op_fn=tf.contrib.learn.no_op_train_fn,
+      logits=logits,
+      scope=...)
+  if mode == tf.contrib.learn.ModeKeys.TRAIN:
+    optimizer = ...
+    sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
+    update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
+                                                loss=model_fn_ops.loss, ...)
+    hooks = [sync.make_session_run_hook(is_chief)]
+    ... update train_op and hooks in ModelFnOps and return
+  ```
+  """
+  __metaclass__ = abc.ABCMeta
+
+  @abc.abstractproperty
+  def logits_dimension(self):
+    """Size of the last dimension of the logits `Tensor`.
+
+    Typically, logits is of shape `[batch_size, logits_dimension]`.
+
+    Returns:
+      The expected size of the `logits` tensor.
+    """
+    raise NotImplementedError("Calling an abstract method.")
+
+  @abc.abstractmethod
+  def create_model_fn_ops(self,
+                          features,
+                          mode,
+                          labels=None,
+                          train_op_fn=None,
+                          logits=None,
+                          logits_input=None,
+                          scope=None):
+    """Returns `ModelFnOps` that a model_fn can return.
+
+    Please note that,
+    + Exactly one of `logits` and `logits_input` must be provided.
+    + All args must be passed via name.
+
+    Args:
+      features: Input `dict` of `Tensor` objects.
+      mode: Estimator's `ModeKeys`.
+      labels: Labels `Tensor`, or `dict` of same.
+      train_op_fn: Function that takes a scalar loss `Tensor` and returns an
+        op to optimize the model with the loss. This is used in TRAIN mode and
+        must not be None. None is allowed in other modes. If you want to
+        optimize loss yourself you can pass `no_op_train_fn` and then use
+        ModelFnOps.loss to compute and apply gradients.
+      logits: logits `Tensor` to be used by the head.
+      logits_input: `Tensor` from which to build logits, often needed when you
+        don't want to compute the logits. Typically this is the activation of
+        the last hidden layer in a DNN. Some heads (like the ones responsible
+        for candidate sampling) intrinsically avoid computing full logits and
+        only accept logits_input.
+      scope: Optional scope for `variable_scope`.
+
+    Returns:
+      An instance of `ModelFnOps`.
+
+    Raises:
+      ValueError: If `mode` is not recognized.
+      ValueError: If neither or both of `logits` and `logits_input` is
+        provided.
+    """
+    raise NotImplementedError("Calling an abstract method.")
+
+
-def _regression_head(label_name=None,
-                     weight_column_name=None,
-                     label_dimension=1,
-                     enable_centered_bias=False,
-                     head_name=None):
-  """Creates a _Head for linear regression.
+def regression_head(label_name=None,
+                    weight_column_name=None,
+                    label_dimension=1,
+                    enable_centered_bias=False,
+                    head_name=None):
+  """Creates a `Head` for linear regression.

   Args:
     label_name: String, name of the key in label dict. Can be null if label
@@ -73,7 +199,7 @@ def _regression_head(label_name=None,
       will be `head_name`.

   Returns:
-    An instance of _Head
+    An instance of `Head` for linear regression.
   """
   return _RegressionHead(
       label_name=label_name,
@@ -85,12 +211,12 @@ def _regression_head(label_name=None,
       link_fn=array_ops.identity)

-def _poisson_regression_head(label_name=None,
-                             weight_column_name=None,
-                             label_dimension=1,
-                             enable_centered_bias=False,
-                             head_name=None):
-  """Creates a _Head for linear regression.
+def poisson_regression_head(label_name=None,
+                            weight_column_name=None,
+                            label_dimension=1,
+                            enable_centered_bias=False,
+                            head_name=None):
+  """Creates a `Head` for poisson regression.

   Args:
     label_name: String, name of the key in label dict. Can be null if label
@@ -109,7 +235,7 @@ def _poisson_regression_head(label_name=None,
       will be `head_name`.

   Returns:
-    An instance of _Head
+    An instance of `Head` for poisson regression.
   """
   return _RegressionHead(
       label_name=label_name,
@@ -120,10 +246,10 @@ def _poisson_regression_head(label_name=None,
       loss_fn=_poisson_loss,
       link_fn=math_ops.exp)

-# TODO(zakaria): Add logistic_regression_head
+# TODO(zakaria): Consider adding a _RegressionHead for logistic_regression

-def _multi_class_head(n_classes,
-                      label_name=None,
-                      weight_column_name=None,
-                      enable_centered_bias=False,
+def multi_class_head(n_classes,
+                     label_name=None,
+                     weight_column_name=None,
+                     enable_centered_bias=False,
@@ -131,7 +257,7 @@ def _multi_class_head(n_classes,
-                      head_name=None,
-                      thresholds=None,
-                      metric_class_ids=None,
-                      loss_fn=None):
-  """Creates a _Head for multi class single label classification.
+                     head_name=None,
+                     thresholds=None,
+                     metric_class_ids=None,
+                     loss_fn=None):
+  """Creates a `Head` for multi class single label classification.

   The Head uses softmax cross entropy loss.
@@ -157,7 +283,7 @@ def _multi_class_head(n_classes,
       optional. See `tf.losses`

   Returns:
-    An instance of _MultiClassHead.
+    An instance of `Head` for multi class classification.

   Raises:
     ValueError: If `n_classes` is < 2, or `metric_class_ids` is provided when
@@ -193,13 +319,13 @@ def _multi_class_head(n_classes,
       loss_fn=loss_fn)

-def _binary_svm_head(
+def binary_svm_head(
     label_name=None,
     weight_column_name=None,
     enable_centered_bias=False,
     head_name=None,
     thresholds=None,):
-  """Creates a `_Head` for binary classification with SVMs.
+  """Creates a `Head` for binary classification with SVMs.

   The head uses binary hinge loss.
@@ -218,8 +344,7 @@ def _binary_svm_head(
     thresholds: thresholds for eval metrics, defaults to [.5]

   Returns:
-    An instance of `_Head`.
+    An instance of `Head` for binary classification with SVM.
   """
   return _BinarySvmHead(
       label_name=label_name,
@@ -229,7 +354,7 @@ def _binary_svm_head(
       thresholds=thresholds)

-def _multi_label_head(n_classes,
-                      label_name=None,
-                      weight_column_name=None,
-                      enable_centered_bias=False,
+def multi_label_head(n_classes,
+                     label_name=None,
+                     weight_column_name=None,
+                     enable_centered_bias=False,
@@ -237,7 +362,7 @@ def _multi_label_head(n_classes,
-                      head_name=None,
-                      thresholds=None,
-                      metric_class_ids=None,
-                      loss_fn=None):
-  """Creates a _Head for multi label classification.
+                     head_name=None,
+                     thresholds=None,
+                     metric_class_ids=None,
+                     loss_fn=None):
+  """Creates a `Head` for multi label classification.

   The Head uses sigmoid cross entropy loss.
@@ -262,7 +387,7 @@ def _multi_label_head(n_classes,
       optional. See `tf.losses`

   Returns:
-    An instance of _MultiLabelHead.
+    An instance of `Head` for multi label classification.

   Raises:
     ValueError: If n_classes is < 2
@@ -284,16 +409,16 @@ def _multi_label_head(n_classes,
       loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)

-def _multi_head(heads, loss_weights=None):
+def multi_head(heads, loss_weights=None):
   """Creates a MultiHead stemming from same logits/hidden layer.

   Args:
-    heads: list of _Head objects.
-    loss_weights: optional list of weights to be used to combine losses from
+    heads: list of Head objects.
+    loss_weights: optional list of weights to be used to merge losses from
       each head. All losses are weighted equally if not provided.

   Returns:
-    A _Head instance that combines multiple heads.
+    An instance of `Head` that merges multiple heads.

   Raises:
     ValueError: if heads and loss_weights have different size.
@@ -302,7 +427,7 @@ def _multi_head(heads, loss_weights=None):
     if len(loss_weights) != len(heads):
       raise ValueError("heads and loss_weights must have same size")

-  def _weighted_loss_combiner(losses):
+  def _weighted_loss_merger(losses):
     if loss_weights:
       if len(losses) != len(loss_weights):
         raise ValueError("losses and loss_weights must have same size")
@@ -313,7 +438,7 @@ def _multi_head(heads, loss_weights=None):
     else:
       return math_ops.add_n(losses)

-  return _MultiHead(heads, loss_combiner=_weighted_loss_combiner)
+  return _MultiHead(heads, loss_merger=_weighted_loss_merger)

 def no_op_train_fn(loss):
@@ -321,64 +446,7 @@ def no_op_train_fn(loss):
   return control_flow_ops.no_op()

-# TODO(zakaria): Make the classes public once we are ready for users to subclass
-# them. See b/34751732
-class _Head(object):
-  """Interface for the head/top of a model.
-
-  Given logits or output of a hidden layer, a Head knows how to compute
-  predictions, loss, default metric and export signature.
-  """
-  __metaclass__ = abc.ABCMeta
-
-  @abc.abstractproperty
-  def logits_dimension(self):
-    """Size of the last dimension of the logits `Tensor`.
-
-    Typically, logits is of shape `[batch_size, logits_dimension]`.
-
-    Returns:
-      Number of logits values per example.
-    """
-    raise NotImplementedError("Calling an abstract method.")
-
-  @abc.abstractmethod
-  def create_model_fn_ops(self,
-                          features,
-                          mode,
-                          labels=None,
-                          train_op_fn=None,
-                          logits=None,
-                          logits_input=None,
-                          scope=None):
-    """Returns ops for a model_fn.
-
-    Exactly one of `logits` and `logits_input` must be provided.
-
-    All args must be passed via name.
-
-    Args:
-      features: Input `dict` of `Tensor` objects.
-      mode: Estimator's `ModeKeys`.
-      labels: Labels `Tensor`, or `dict` of same.
-      train_op_fn: Function that takes a scalar loss and returns an op to
-        optimize with the loss. Must not be `None` in TRAIN mode. If you want
-        to optimize loss yourself you can pass `no_op_train_fn`.
-      logits: logits `Tensor`, or `dict` of same, to be used for the head.
-      logits_input: `Tensor` from which to build logits.
-      scope: Optional scope for `variable_scope`.
-
-    Returns:
-      `ModelFnOps`.
-
-    Raises:
-      ValueError: if `mode` is not recognized, or neither or both of `logits`
-        and `logits_input` is provided.
-    """
-    raise NotImplementedError("Calling an abstract method.")
-
-
-class _SingleHead(_Head):
+class _SingleHead(Head):
   """Interface for a single head/top of a model."""

   __metaclass__ = abc.ABCMeta
@@ -565,7 +633,7 @@ def _create_model_fn_ops(features,

 class _RegressionHead(_SingleHead):
-  """_Head for regression with a generalized linear model."""
+  """`Head` for regression with a generalized linear model."""

   def __init__(self,
                label_dimension,
@@ -575,7 +643,7 @@ class _RegressionHead(_SingleHead):
                weight_column_name=None,
                enable_centered_bias=False,
                head_name=None):
-    """Head for regression.
+    """`Head` for regression.

     Args:
       label_dimension: Number of regression labels per example. This is the
@@ -614,7 +682,7 @@ class _RegressionHead(_SingleHead):
                           logits=None,
                           logits_input=None,
                           scope=None):
-    """See `_Head`."""
+    """See `Head`."""
     return _create_model_fn_ops(
         features=features,
         mode=mode,
@@ -682,7 +750,7 @@ def _one_class_to_two_class_logits(logits):

 class _BinaryLogisticHead(_SingleHead):
-  """_Head for binary logistic classifciation."""
+  """`Head` for binary classification with logistic regression."""

   def __init__(self,
                label_name=None,
@@ -691,7 +759,7 @@ class _BinaryLogisticHead(_SingleHead):
                head_name=None,
                loss_fn=None,
                thresholds=None):
-    """Base type for all single heads.
+    """`Head` for binary classification with logistic regression.

     Args:
       label_name: String, name of the key in label dict. Can be `None` if label
@@ -729,7 +797,7 @@ class _BinaryLogisticHead(_SingleHead):
                           logits=None,
                           logits_input=None,
                           scope=None):
-    """See `_Head`."""
+    """See `Head`."""
     return _create_model_fn_ops(
         features=features,
         mode=mode,
@@ -844,7 +912,7 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None):

 class _MultiClassHead(_SingleHead):
-  """_Head for classification."""
+  """`Head` for multi class classification."""

   def __init__(self,
                n_classes,
@@ -855,7 +923,7 @@ class _MultiClassHead(_SingleHead):
                loss_fn=None,
                thresholds=None,
                metric_class_ids=None):
-    """_Head for classification.
+    """`Head` for multi class classification.

     Args:
       n_classes: Number of classes, must be greater than 2 (for 2 classes, use
@@ -905,7 +973,7 @@ class _MultiClassHead(_SingleHead):
                           logits=None,
                           logits_input=None,
                           scope=None):
-    """See `_Head`."""
+    """See `Head`."""
     return _create_model_fn_ops(
         features=features,
         mode=mode,
@@ -1039,7 +1107,7 @@ def _assert_labels_rank(labels):

 class _BinarySvmHead(_SingleHead):
-  """_Head for binary classification using SVMs."""
+  """`Head` for binary classification using SVM."""

   def __init__(self, label_name, weight_column_name, enable_centered_bias,
                head_name, thresholds):
@@ -1069,7 +1137,7 @@ class _BinarySvmHead(_SingleHead):
                           logits=None,
                           logits_input=None,
                           scope=None):
-    """See `_Head`."""
+    """See `Head`."""
     return _create_model_fn_ops(
         features=features,
         mode=mode,
@@ -1125,7 +1193,7 @@ class _BinarySvmHead(_SingleHead):

 class _MultiLabelHead(_SingleHead):
-  """_Head for multlabel classification."""
+  """`Head` for multi-label classification."""

   # TODO(zakaria): add signature and metric for multilabel.
   def __init__(self,
@@ -1162,7 +1230,7 @@ class _MultiLabelHead(_SingleHead):
                           logits=None,
                           logits_input=None,
                           scope=None):
-    """See `_Head`."""
+    """See `Head`."""
     return _create_model_fn_ops(
         features=features,
         mode=mode,
@@ -1240,24 +1308,52 @@ class _MultiLabelHead(_SingleHead):
     return metrics

-class _MultiHead(_Head):
-  """_Head to combine multiple _Head objects.
+class _MultiHead(Head):
+  """`Head` implementation for multi objective learning.
+
+  This class is responsible for using and merging the output of multiple
+  `Head` objects.

   All heads stem from the same logits/logit_input tensor.

-  For training, combines losses of each heads according a function provided by
-  user.
-  For eval, adds a /head_name suffix to the keys in eval metrics.
-  For inference, updates keys prediction dict to a 2-tuple,
+  Common usage:
+  For simple use cases you can pass the activation of hidden layer like
+  this from your model_fn,
+  ```python
+  last_hidden_layer_activation = ... Build your model.
+  multi_head = ...
+  return multi_head.create_model_fn_ops(
+      ..., logits_input=last_hidden_layer_activation, ...)
+  ```
+
+  Or you can create a logits tensor of
+  [batch_size, multi_head.logits_dimension] shape. _MultiHead will split the
+  logits for you.
+  return multi_head.create_model_fn_ops(..., logits=logits, ...)
+
+  For more complex use cases like a multi-task/multi-tower model or when logits
+  for each head has to be created separately, you can pass a dict of logits
+  where the keys match the name of the single heads.
+  ```python
+  logits = {"head1": logits1, "head2": logits2}
+  return multi_head.create_model_fn_ops(..., logits=logits, ...)
+  ```
+
+  Here is what this class does,
+  + For training, merges losses of each head according to a function provided
+    by the user, then calls the user provided train_op_fn with this final loss.
+  + For eval, merges metrics by adding a head_name suffix to the keys in eval
+    metrics.
+  + For inference, updates keys in the prediction dict to a 2-tuple,
     (head_name, prediction_key)
   """

-  def __init__(self, heads, loss_combiner):
-    """_Head to combine multiple _Head objects.
+  def __init__(self, heads, loss_merger):
+    """_Head to merge multiple _Head objects.

     Args:
       heads: list of _Head objects.
-      loss_combiner: function that takes a list of loss tensors for the heads
+      loss_merger: function that takes a list of loss tensors for the heads
         and returns the final loss tensor for the multi head.

     Raises:
@@ -1274,7 +1370,7 @@ class _MultiHead(_Head):
       self._logits_dimension += head.logits_dimension

     self._heads = heads
-    self._loss_combiner = loss_combiner
+    self._loss_merger = loss_merger

   @property
   def logits_dimension(self):
@@ -1353,11 +1449,11 @@ class _MultiHead(_Head):
     if mode == model_fn.ModeKeys.TRAIN:
       if train_op_fn is None:
         raise ValueError("train_op_fn can not be None in TRAIN mode.")
-      return self._combine_train(all_model_fn_ops, train_op_fn)
+      return self._merge_train(all_model_fn_ops, train_op_fn)
     if mode == model_fn.ModeKeys.INFER:
-      return self._combine_infer(all_model_fn_ops)
+      return self._merge_infer(all_model_fn_ops)
     if mode == model_fn.ModeKeys.EVAL:
-      return self._combine_eval(all_model_fn_ops)
+      return self._merge_eval(all_model_fn_ops)
     raise ValueError("mode=%s unrecognized" % str(mode))

   def _split_logits(self, logits):
@@ -1379,8 +1475,8 @@ class _MultiHead(_Head):
       begin += current_logits_size
     return all_logits

-  def _combine_train(self, all_model_fn_ops, train_op_fn):
-    """Combines list of ModelFnOps for training.
+  def _merge_train(self, all_model_fn_ops, train_op_fn):
+    """Merges list of ModelFnOps for training.

     Args:
       all_model_fn_ops: list of ModelFnOps for the individual heads.
@@ -1388,14 +1484,14 @@ class _MultiHead(_Head):
         documentation for more details.

     Returns:
-      ModelFnOps that combines all the heads.
+      ModelFnOps that merges all heads for TRAIN.
     """
     losses = []
     additional_train_ops = []
     for m in all_model_fn_ops:
       losses.append(m.loss)
       additional_train_ops.append(m.train_op)
-    loss = self._loss_combiner(losses)
+    loss = self._loss_merger(losses)

     train_op = train_op_fn(loss)
     train_op = control_flow_ops.group(train_op, *additional_train_ops)
@@ -1404,14 +1500,14 @@ class _MultiHead(_Head):
         loss=loss,
         train_op=train_op)

-  def _combine_infer(self, all_model_fn_ops):
-    """Combines list of ModelFnOps for inference.
+  def _merge_infer(self, all_model_fn_ops):
+    """Merges list of ModelFnOps for inference.

     Args:
       all_model_fn_ops: list of ModelFnOps for the individual heads.

     Returns:
-      ModelFnOps that combines all the heads.
+      ModelFnOps that merges all the heads for INFER.
     """
     predictions = {}
     output_alternatives = {}
@@ -1426,14 +1522,14 @@ class _MultiHead(_Head):
         predictions=predictions,
         output_alternatives=output_alternatives)

-  def _combine_eval(self, all_model_fn_ops):
-    """Combines list of ModelFnOps for eval.
+  def _merge_eval(self, all_model_fn_ops):
+    """Merges list of ModelFnOps for eval.

     Args:
       all_model_fn_ops: list of ModelFnOps for the individual heads.

     Returns:
-      ModelFnOps that combines all the heads.
+      ModelFnOps that merges all the heads for EVAL.
     """
     predictions = {}
     metrics = {}
@@ -1446,7 +1542,7 @@ class _MultiHead(_Head):
       for k, v in m.eval_metric_ops.items():
         # metrics["%s/%s" % (k, head_name)] = v
         metrics[k] = v
-    loss = self._loss_combiner(losses)
+    loss = self._loss_merger(losses)

     return model_fn.ModelFnOps(
         mode=model_fn.ModeKeys.EVAL,
@@ -1733,3 +1829,14 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
       predictions, labels=labels, thresholds=(threshold,),
       weights=_float_weights_or_none(weights))
   return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
+
+
+# Aliases
+# TODO(zakaria): Remove these aliases, See b/34751732
+_regression_head = regression_head
+_poisson_regression_head = poisson_regression_head
+_multi_class_head = multi_class_head
+_binary_svm_head = binary_svm_head
+_multi_label_head = multi_label_head
+_multi_head = multi_head
+_Head = Head
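Putting the public pieces together, a sketch of multi objective learning with two heads whose losses are merged 2:1 (the head names, label keys, and weights here are illustrative):

```python
from tensorflow.contrib import learn

# Two objectives sharing one model; losses merged by _weighted_loss_merger.
class_head = learn.multi_class_head(n_classes=3, head_name="class")
score_head = learn.regression_head(label_name="score", head_name="score")
head = learn.multi_head([class_head, score_head], loss_weights=[2.0, 1.0])

# Inside a model_fn, the multi head can split a combined logits Tensor with
# head.logits_dimension columns, or accept a dict of logits keyed by head_name:
# model_fn_ops = head.create_model_fn_ops(
#     features=features, labels=labels, mode=mode,
#     train_op_fn=learn.no_op_train_fn, logits=logits)
```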

head_test.py

@ -112,7 +112,7 @@ class PoissonHeadTest(test.TestCase):
return sum(lpl)/len(lpl) return sum(lpl)/len(lpl)
def testPoissonWithLogits(self): def testPoissonWithLogits(self):
head = head_lib._poisson_regression_head() head = head_lib.poisson_regression_head()
labels = ((0.,), (1.,), (1.,)) labels = ((0.,), (1.,), (1.,))
logits = ((0.,), (-1.,), (3.,)) logits = ((0.,), (-1.,), (3.,))
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
@ -140,7 +140,7 @@ class RegressionHeadTest(test.TestCase):
# TODO(zakaria): test multilabel regression. # TODO(zakaria): test multilabel regression.
def testRegressionWithLogits(self): def testRegressionWithLogits(self):
head = head_lib._regression_head() head = head_lib.regression_head()
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops( model_fn_ops = head.create_model_fn_ops(
{}, {},
@ -154,7 +154,7 @@ class RegressionHeadTest(test.TestCase):
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionWithInvalidLogits(self): def testRegressionWithInvalidLogits(self):
head = head_lib._regression_head() head = head_lib.regression_head()
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
head.create_model_fn_ops( head.create_model_fn_ops(
@ -165,7 +165,7 @@ class RegressionHeadTest(test.TestCase):
logits=((1., 1.), (1., 1.), (3., 1.))) logits=((1., 1.), (1., 1.), (3., 1.)))
def testRegressionWithLogitsInput(self): def testRegressionWithLogitsInput(self):
head = head_lib._regression_head() head = head_lib.regression_head()
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops( model_fn_ops = head.create_model_fn_ops(
{}, {},
@ -183,7 +183,7 @@ class RegressionHeadTest(test.TestCase):
_assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops) _assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops)
def testRegressionWithLogitsAndLogitsInput(self): def testRegressionWithLogitsAndLogitsInput(self):
head = head_lib._regression_head() head = head_lib.regression_head()
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, "Both logits and logits_input supplied"): ValueError, "Both logits and logits_input supplied"):
@ -196,7 +196,7 @@ class RegressionHeadTest(test.TestCase):
logits=((1.,), (1.,), (3.,))) logits=((1.,), (1.,), (3.,)))
def testRegressionEvalMode(self): def testRegressionEvalMode(self):
head = head_lib._regression_head() head = head_lib.regression_head()
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops( model_fn_ops = head.create_model_fn_ops(
{}, {},
@ -212,7 +212,7 @@ class RegressionHeadTest(test.TestCase):
def testRegressionWithLabelName(self): def testRegressionWithLabelName(self):
label_name = "my_label" label_name = "my_label"
head = head_lib._regression_head(label_name=label_name) head = head_lib.regression_head(label_name=label_name)
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops( model_fn_ops = head.create_model_fn_ops(
{}, {},
@ -226,7 +226,7 @@ class RegressionHeadTest(test.TestCase):
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionWithWeights(self): def testRegressionWithWeights(self):
head = head_lib._regression_head(weight_column_name="label_weight") head = head_lib.regression_head(weight_column_name="label_weight")
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
weights = ((2.,), (5.,), (0.,)) weights = ((2.,), (5.,), (0.,))
model_fn_ops = head.create_model_fn_ops( model_fn_ops = head.create_model_fn_ops(
@ -242,7 +242,7 @@ class RegressionHeadTest(test.TestCase):
model_fn_ops) model_fn_ops)
def testRegressionWithCenteredBias(self): def testRegressionWithCenteredBias(self):
head = head_lib._regression_head(enable_centered_bias=True) head = head_lib.regression_head(enable_centered_bias=True)
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops( model_fn_ops = head.create_model_fn_ops(
{}, {},
@ -264,7 +264,7 @@ class RegressionHeadTest(test.TestCase):
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionErrorInSparseTensorLabels(self): def testRegressionErrorInSparseTensorLabels(self):
head = head_lib._regression_head() head = head_lib.regression_head()
with ops.Graph().as_default(): with ops.Graph().as_default():
labels = sparse_tensor.SparseTensorValue( labels = sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (2, 0)), indices=((0, 0), (1, 0), (2, 0)),
@ -317,7 +317,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiLabelWithLogits(self): def testMultiLabelWithLogits(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes)) n_classes=n_classes, metric_class_ids=range(n_classes))
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops( model_fn_ops = head.create_model_fn_ops(
@ -334,7 +334,7 @@ class MultiLabelHeadTest(test.TestCase):
n_classes = 2 n_classes = 2
labels = ((0, 1),) labels = ((0, 1),)
logits = ((1., 0.),) logits = ((1., 0.),)
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes)) n_classes=n_classes, metric_class_ids=range(n_classes))
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops( model_fn_ops = head.create_model_fn_ops(
@ -361,7 +361,7 @@ class MultiLabelHeadTest(test.TestCase):
}, model_fn_ops) }, model_fn_ops)
def testMultiLabelWithInvalidLogits(self): def testMultiLabelWithInvalidLogits(self):
head = head_lib._multi_label_head(n_classes=len(self._labels[0]) + 1) head = head_lib.multi_label_head(n_classes=len(self._labels[0]) + 1)
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
head.create_model_fn_ops( head.create_model_fn_ops(
@ -370,7 +370,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiLabelWithLogitsInput(self): def testMultiLabelWithLogitsInput(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes)) n_classes=n_classes, metric_class_ids=range(n_classes))
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops( model_fn_ops = head.create_model_fn_ops(
@ -407,7 +407,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiLabelWithLogitsAndLogitsInput(self): def testMultiLabelWithLogitsAndLogitsInput(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes)) n_classes=n_classes, metric_class_ids=range(n_classes))
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
@ -418,7 +418,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiLabelEvalMode(self): def testMultiLabelEvalMode(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes)) n_classes=n_classes, metric_class_ids=range(n_classes))
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops( model_fn_ops = head.create_model_fn_ops(
@ -434,7 +434,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiClassEvalModeWithLargeLogits(self): def testMultiClassEvalModeWithLargeLogits(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes)) n_classes=n_classes, metric_class_ids=range(n_classes))
logits = ((2., 0., -1),) logits = ((2., 0., -1),)
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
@ -474,7 +474,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiLabelWithLabelName(self): def testMultiLabelWithLabelName(self):
n_classes = 3 n_classes = 3
label_name = "my_label" label_name = "my_label"
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, n_classes=n_classes,
label_name=label_name, label_name=label_name,
metric_class_ids=range(n_classes)) metric_class_ids=range(n_classes))
@ -491,7 +491,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiLabelWithWeight(self): def testMultiLabelWithWeight(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, n_classes=n_classes,
weight_column_name="label_weight", weight_column_name="label_weight",
metric_class_ids=range(n_classes)) metric_class_ids=range(n_classes))
@ -510,7 +510,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiLabelWithCustomLoss(self): def testMultiLabelWithCustomLoss(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, n_classes=n_classes,
weight_column_name="label_weight", weight_column_name="label_weight",
metric_class_ids=range(n_classes), metric_class_ids=range(n_classes),
@ -530,7 +530,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiLabelWithCenteredBias(self): def testMultiLabelWithCenteredBias(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, n_classes=n_classes,
enable_centered_bias=True, enable_centered_bias=True,
metric_class_ids=range(n_classes)) metric_class_ids=range(n_classes))
@ -559,7 +559,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiLabelSparseTensorLabels(self): def testMultiLabelSparseTensorLabels(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes)) n_classes=n_classes, metric_class_ids=range(n_classes))
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
labels = sparse_tensor.SparseTensorValue( labels = sparse_tensor.SparseTensorValue(
@ -580,7 +580,7 @@ class MultiLabelHeadTest(test.TestCase):
def testMultiLabelSparseTensorLabelsTooFewClasses(self): def testMultiLabelSparseTensorLabelsTooFewClasses(self):
n_classes = 3 n_classes = 3
head = head_lib._multi_label_head( head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes)) n_classes=n_classes, metric_class_ids=range(n_classes))
# Set _logits_dimension (n_classes) to a lower value; if it's set to 1 # Set _logits_dimension (n_classes) to a lower value; if it's set to 1
# upfront, the class throws an error during initialization. # upfront, the class throws an error during initialization.
@ -629,7 +629,7 @@ class BinaryClassificationHeadTest(test.TestCase):
def testBinaryClassificationWithLogits(self): def testBinaryClassificationWithLogits(self):
n_classes = 2 n_classes = 2
head = head_lib._multi_class_head(n_classes=n_classes) head = head_lib.multi_class_head(n_classes=n_classes)
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
# logloss: z:label, x:logit # logloss: z:label, x:logit
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
@ -644,7 +644,7 @@ class BinaryClassificationHeadTest(test.TestCase):
self._expected_eval_metrics(expected_loss), model_fn_ops) self._expected_eval_metrics(expected_loss), model_fn_ops)
def testBinaryClassificationWithInvalidLogits(self): def testBinaryClassificationWithInvalidLogits(self):
head = head_lib._multi_class_head(n_classes=len(self._labels) + 1) head = head_lib.multi_class_head(n_classes=len(self._labels) + 1)
with ops.Graph().as_default(), session.Session(): with ops.Graph().as_default(), session.Session():
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
head.create_model_fn_ops( head.create_model_fn_ops(
@@ -653,7 +653,7 @@ class BinaryClassificationHeadTest(test.TestCase):

  def testBinaryClassificationWithLogitsInput(self):
    n_classes = 2
-    head = head_lib._multi_class_head(n_classes=n_classes)
+    head = head_lib.multi_class_head(n_classes=n_classes)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
@@ -682,7 +682,7 @@ class BinaryClassificationHeadTest(test.TestCase):
      }, model_fn_ops)

  def testBinaryClassificationWithLogitsAndLogitsInput(self):
-    head = head_lib._multi_class_head(n_classes=2)
+    head = head_lib.multi_class_head(n_classes=2)
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
          ValueError, "Both logits and logits_input supplied"):
@@ -692,7 +692,7 @@ class BinaryClassificationHeadTest(test.TestCase):

  def testBinaryClassificationEvalMode(self):
    n_classes = 2
-    head = head_lib._multi_class_head(n_classes=n_classes)
+    head = head_lib.multi_class_head(n_classes=n_classes)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
@@ -709,7 +709,7 @@ class BinaryClassificationHeadTest(test.TestCase):

  def testBinaryClassificationInferMode(self):
    n_classes = 2
-    head = head_lib._multi_class_head(n_classes=n_classes)
+    head = head_lib.multi_class_head(n_classes=n_classes)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
@@ -722,7 +722,7 @@ class BinaryClassificationHeadTest(test.TestCase):

  def testBinaryClassificationInferMode_withWightColumn(self):
    n_classes = 2
-    head = head_lib._multi_class_head(n_classes=n_classes,
-                                      weight_column_name="label_weight")
+    head = head_lib.multi_class_head(n_classes=n_classes,
+                                     weight_column_name="label_weight")
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
@@ -738,7 +738,7 @@ class BinaryClassificationHeadTest(test.TestCase):

  def testErrorInSparseTensorLabels(self):
    n_classes = 2
-    head = head_lib._multi_class_head(n_classes=n_classes)
+    head = head_lib.multi_class_head(n_classes=n_classes)
    with ops.Graph().as_default():
      labels = sparse_tensor.SparseTensorValue(
          indices=((0, 0), (1, 0), (2, 0)),
@@ -755,7 +755,7 @@ class BinaryClassificationHeadTest(test.TestCase):

  def testBinaryClassificationWithLabelName(self):
    label_name = "my_label"
-    head = head_lib._multi_class_head(n_classes=2, label_name=label_name)
+    head = head_lib.multi_class_head(n_classes=2, label_name=label_name)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
@@ -774,7 +774,7 @@ class BinaryClassificationHeadTest(test.TestCase):

  def testBinaryClassificationWithWeights(self):
    n_classes = 2
-    head = head_lib._multi_class_head(
+    head = head_lib.multi_class_head(
        n_classes=n_classes, weight_column_name="label_weight")
    with ops.Graph().as_default(), session.Session():
      weights = ((1.,), (0.,))
@@ -808,7 +808,7 @@ class BinaryClassificationHeadTest(test.TestCase):
                     model_fn_ops)

  def testBinaryClassificationWithCustomLoss(self):
-    head = head_lib._multi_class_head(
+    head = head_lib.multi_class_head(
        n_classes=2, weight_column_name="label_weight",
        loss_fn=_sigmoid_cross_entropy)
    with ops.Graph().as_default(), session.Session():
@@ -844,7 +844,7 @@ class BinaryClassificationHeadTest(test.TestCase):
                     model_fn_ops)

  def testBinaryClassificationWithCenteredBias(self):
-    head = head_lib._multi_class_head(n_classes=2, enable_centered_bias=True)
+    head = head_lib.multi_class_head(n_classes=2, enable_centered_bias=True)
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
      # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
@@ -904,7 +904,7 @@ class MultiClassHeadTest(test.TestCase):

  def testMultiClassWithLogits(self):
    n_classes = 3
-    head = head_lib._multi_class_head(
+    head = head_lib.multi_class_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
@@ -920,7 +920,7 @@ class MultiClassHeadTest(test.TestCase):
          self._expected_eval_metrics(expected_loss), model_fn_ops)

  def testMultiClassWithInvalidLogits(self):
-    head = head_lib._multi_class_head(n_classes=len(self._logits[0]) + 1)
+    head = head_lib.multi_class_head(n_classes=len(self._logits[0]) + 1)
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
        head.create_model_fn_ops(
@@ -928,7 +928,7 @@ class MultiClassHeadTest(test.TestCase):
            logits=self._logits)

  def testMultiClassWithNoneTrainOpFnInTrain(self):
-    head = head_lib._multi_class_head(n_classes=3)
+    head = head_lib.multi_class_head(n_classes=3)
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
          ValueError, "train_op_fn can not be None in TRAIN mode"):
@@ -939,7 +939,7 @@ class MultiClassHeadTest(test.TestCase):

  def testMultiClassWithLogitsInput(self):
    n_classes = 3
-    head = head_lib._multi_class_head(
+    head = head_lib.multi_class_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
@@ -978,7 +978,7 @@ class MultiClassHeadTest(test.TestCase):

  def testMultiClassWithLogitsAndLogitsInput(self):
    n_classes = 3
-    head = head_lib._multi_class_head(
+    head = head_lib.multi_class_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
@@ -989,7 +989,7 @@ class MultiClassHeadTest(test.TestCase):

  def testMultiClassEvalMode(self):
    n_classes = 3
-    head = head_lib._multi_class_head(
+    head = head_lib.multi_class_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    with ops.Graph().as_default(), session.Session():
      # logloss: z:label, x:logit
@@ -1007,7 +1007,7 @@ class MultiClassHeadTest(test.TestCase):

  def testMultiClassEvalModeWithLargeLogits(self):
    n_classes = 3
-    head = head_lib._multi_class_head(
+    head = head_lib.multi_class_head(
        n_classes=n_classes, metric_class_ids=range(n_classes))
    logits = ((2., 0., -1),)
    with ops.Graph().as_default(), session.Session():
@@ -1046,7 +1046,7 @@ class MultiClassHeadTest(test.TestCase):

  def testMultiClassWithWeight(self):
    n_classes = 3
-    head = head_lib._multi_class_head(
+    head = head_lib.multi_class_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes))
@@ -1069,7 +1069,7 @@ class MultiClassHeadTest(test.TestCase):

  def testMultiClassWithCustomLoss(self):
    n_classes = 3
-    head = head_lib._multi_class_head(
+    head = head_lib.multi_class_head(
        n_classes=n_classes,
        weight_column_name="label_weight",
        metric_class_ids=range(n_classes),
@@ -1094,7 +1094,7 @@ class MultiClassHeadTest(test.TestCase):
  def testInvalidNClasses(self):
    for n_classes in (None, -1, 0, 1):
      with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):
-        head_lib._multi_class_head(n_classes=n_classes)
+        head_lib.multi_class_head(n_classes=n_classes)


class BinarySvmHeadTest(test.TestCase):
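The test updates above show the full shape of the now-public multi-class factory. For orientation, a minimal sketch of driving it directly (assumes a TF 1.x build with this CL applied; the logits/labels values and the use of no_op_train_fn as a stand-in train_op_fn are illustrative, not from this CL):

    # Sketch: the public multi-class head; values are illustrative only.
    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
    from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib

    head = head_lib.multi_class_head(n_classes=3)  # was head_lib._multi_class_head
    model_fn_ops = head.create_model_fn_ops(
        features={},
        mode=model_fn_lib.ModeKeys.TRAIN,
        labels=tf.constant((1,)),
        train_op_fn=head_lib.no_op_train_fn,  # real models would apply gradients here
        logits=tf.constant(((-0.7, 0.2, 0.1),)))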
@@ -1116,7 +1116,7 @@ class BinarySvmHeadTest(test.TestCase):
    self._expected_losses = (.5, 0.)

  def testBinarySVMWithLogits(self):
-    head = head_lib._binary_svm_head()
+    head = head_lib.binary_svm_head()
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
@@ -1134,7 +1134,7 @@ class BinarySvmHeadTest(test.TestCase):
      }, model_fn_ops)

  def testBinarySVMWithInvalidLogits(self):
-    head = head_lib._binary_svm_head()
+    head = head_lib.binary_svm_head()
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
        head.create_model_fn_ops(
@@ -1142,7 +1142,7 @@ class BinarySvmHeadTest(test.TestCase):
            logits=np.ones((2, 2)))

  def testBinarySVMWithLogitsInput(self):
-    head = head_lib._binary_svm_head()
+    head = head_lib.binary_svm_head()
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
@@ -1164,7 +1164,7 @@ class BinarySvmHeadTest(test.TestCase):
      }, model_fn_ops)

  def testBinarySVMWithLogitsAndLogitsInput(self):
-    head = head_lib._binary_svm_head()
+    head = head_lib.binary_svm_head()
    with ops.Graph().as_default(), session.Session():
      with self.assertRaisesRegexp(
          ValueError, "Both logits and logits_input supplied"):
@@ -1177,7 +1177,7 @@ class BinarySvmHeadTest(test.TestCase):
          logits=self._predictions)

  def testBinarySVMEvalMode(self):
-    head = head_lib._binary_svm_head()
+    head = head_lib.binary_svm_head()
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
@@ -1197,7 +1197,7 @@ class BinarySvmHeadTest(test.TestCase):

  def testBinarySVMWithLabelName(self):
    label_name = "my_label"
-    head = head_lib._binary_svm_head(label_name=label_name)
+    head = head_lib.binary_svm_head(label_name=label_name)
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
@@ -1215,7 +1215,7 @@ class BinarySvmHeadTest(test.TestCase):
      }, model_fn_ops)

  def testBinarySVMWithWeights(self):
-    head = head_lib._binary_svm_head(weight_column_name="weights")
+    head = head_lib.binary_svm_head(weight_column_name="weights")
    with ops.Graph().as_default(), session.Session():
      weights = (7., 11.)
      model_fn_ops = head.create_model_fn_ops(
@@ -1235,7 +1235,7 @@ class BinarySvmHeadTest(test.TestCase):
      }, model_fn_ops)

  def testBinarySVMWithCenteredBias(self):
-    head = head_lib._binary_svm_head(enable_centered_bias=True)
+    head = head_lib.binary_svm_head(enable_centered_bias=True)
    with ops.Graph().as_default(), session.Session():
      model_fn_ops = head.create_model_fn_ops(
          {},
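The same rename applies to the hinge-loss head. A hedged sketch of evaluating its loss outside the test harness (the logits/labels mirror the fixtures above, where the expected per-example losses are .5 and 0.; assumes a TF 1.x runtime):

    # Sketch: hinge loss from the public binary SVM head; fixture-style values.
    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
    from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib

    head = head_lib.binary_svm_head()  # was head_lib._binary_svm_head
    model_fn_ops = head.create_model_fn_ops(
        features={},
        mode=model_fn_lib.ModeKeys.EVAL,
        labels=tf.constant((1, 0)),
        train_op_fn=None,  # only required in TRAIN mode
        logits=tf.constant(((.5,), (-1.,))))
    with tf.Session() as sess:
      print(sess.run(model_fn_ops.loss))  # mean hinge loss over the batch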
@@ -1265,21 +1265,21 @@ class BinarySvmHeadTest(test.TestCase):
class MultiHeadTest(test.TestCase):

  def testInvalidHeads(self):
-    named_head = head_lib._multi_class_head(
+    named_head = head_lib.multi_class_head(
        n_classes=3, label_name="label", head_name="head1")
-    unnamed_head = head_lib._multi_class_head(
+    unnamed_head = head_lib.multi_class_head(
        n_classes=4, label_name="label")
    with self.assertRaisesRegexp(ValueError, "must have names"):
-      head_lib._multi_head((named_head, unnamed_head))
+      head_lib.multi_head((named_head, unnamed_head))
    with self.assertRaisesRegexp(ValueError, "must be SingleHead"):
-      head_lib._multi_head((named_head, head_lib._multi_head((named_head,))))
+      head_lib.multi_head((named_head, head_lib.multi_head((named_head,))))

  def testTrainWithNoneTrainOpFn(self):
-    head1 = head_lib._multi_class_head(
+    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
-    head2 = head_lib._multi_class_head(
+    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
-    head = head_lib._multi_head((head1, head2))
+    head = head_lib.multi_head((head1, head2))
    labels = {
        "label1": (1,),
        "label2": (1,)
@@ -1294,11 +1294,11 @@ class MultiHeadTest(test.TestCase):
          logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))

  def testTrain_withNoHeadWeights(self):
-    head1 = head_lib._multi_class_head(
+    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
-    head2 = head_lib._multi_class_head(
+    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
-    head = head_lib._multi_head((head1, head2))
+    head = head_lib.multi_head((head1, head2))
    labels = {
        "label1": (1,),
        "label2": (1,)
@@ -1320,11 +1320,11 @@ class MultiHeadTest(test.TestCase):
      self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)

  def testTrain_withHeadWeights(self):
-    head1 = head_lib._multi_class_head(
+    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
-    head2 = head_lib._multi_class_head(
+    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
-    head = head_lib._multi_head((head1, head2), (1, .5))
+    head = head_lib.multi_head((head1, head2), (1, .5))
    labels = {
        "label1": (1,),
        "label2": (1,)
@@ -1345,11 +1345,11 @@ class MultiHeadTest(test.TestCase):
      self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3)

  def testTrain_withDictLogits(self):
-    head1 = head_lib._multi_class_head(
+    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
-    head2 = head_lib._multi_class_head(
+    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
-    head = head_lib._multi_head((head1, head2))
+    head = head_lib.multi_head((head1, head2))
    labels = {
        "label1": (1,),
        "label2": (1,)
@@ -1372,11 +1372,11 @@ class MultiHeadTest(test.TestCase):
      self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)

  def testInfer(self):
-    head1 = head_lib._multi_class_head(
+    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
-    head2 = head_lib._multi_class_head(
+    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
-    head = head_lib._multi_head((head1, head2), (1, .5))
+    head = head_lib.multi_head((head1, head2), (1, .5))
    labels = {
        "label1": (1,),
        "label2": (1,)
@@ -1422,11 +1422,11 @@ class MultiHeadTest(test.TestCase):
      ), model_fn_ops.output_alternatives["head2"][1].keys())

  def testEval(self):
-    head1 = head_lib._multi_class_head(
+    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
-    head2 = head_lib._multi_class_head(
+    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")
-    head = head_lib._multi_head((head1, head2), (1, .5))
+    head = head_lib.multi_head((head1, head2), (1, .5))
    labels = {
        "label1": (1,),
        "label2": (1,)

View File

@@ -419,7 +419,7 @@ class LinearClassifier(estimator.Estimator):
      enable_centered_bias = False
      logging.warning("centered_bias is not supported with SDCA, "
                      "please disable it explicitly.")
-    head = head_lib._multi_class_head(  # pylint: disable=protected-access
+    head = head_lib.multi_class_head(
        n_classes,
        weight_column_name=weight_column_name,
        enable_centered_bias=enable_centered_bias)
@@ -686,7 +686,7 @@ class LinearRegressor(estimator.Estimator):
      enable_centered_bias = False
      logging.warning("centered_bias is not supported with SDCA, "
                      "please disable it explicitly.")
-    head = head_lib._regression_head(  # pylint: disable=protected-access
+    head = head_lib.regression_head(
        weight_column_name=weight_column_name,
        label_dimension=label_dimension,
        enable_centered_bias=enable_centered_bias)
@@ -824,8 +824,7 @@ class LinearRegressor(estimator.Estimator):
        exports_to_keep=exports_to_keep)


-# TODO(zakaria): Make it public when b/34751732 is fixed.
-class _LinearEstimator(estimator.Estimator):
+class LinearEstimator(estimator.Estimator):
  """Linear model with user specified head.

  Train a generalized linear model to predict label value given observation of
@@ -840,9 +839,9 @@ class LinearEstimator(estimator.Estimator):

    sparse_feature_a_x_sparse_feature_b = crossed_column(...)

-    estimator = _LinearEstimator(
+    estimator = LinearEstimator(
        feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
-        head=head_lib._poisson_regression_head())
+        head=head_lib.poisson_regression_head())

    # Input builders
    def input_fn_train: # returns x, y
@@ -879,7 +878,7 @@ class LinearEstimator(estimator.Estimator):
               _joint_weights=False,
               config=None,
               feature_engineering_fn=None):
-    """Construct a `_LinearEstimator` object.
+    """Construct a `LinearEstimator` object.

    Args:
      feature_columns: An iterable containing all the feature columns used by
@@ -907,14 +906,14 @@ class LinearEstimator(estimator.Estimator):
        into the model.

    Returns:
-      A `_LinearEstimator` estimator.
+      A `LinearEstimator` estimator.

    Raises:
      ValueError: if optimizer is not supported, e.g., SDCAOptimizer
    """
    assert feature_columns
    if isinstance(optimizer, sdca_optimizer.SDCAOptimizer):
-      raise ValueError("_LinearEstimator does not support SDCA optimizer.")
+      raise ValueError("LinearEstimator does not support SDCA optimizer.")

    params = {
        "head": head,
@@ -923,7 +922,7 @@ class LinearEstimator(estimator.Estimator):
        "gradient_clip_norm": gradient_clip_norm,
        "joint_weights": _joint_weights,
    }
-    super(_LinearEstimator, self).__init__(
+    super(LinearEstimator, self).__init__(
        model_fn=_linear_model_fn,
        model_dir=model_dir,
        config=config,
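With the class now public, a user-specified head can be wired into a linear model without touching private symbols. A hedged sketch (the feature column and the commented-out input_fn are placeholders for illustration, not code from this CL):

    # Sketch: LinearEstimator with a user-specified head (TF 1.x contrib.learn).
    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
    from tensorflow.contrib.learn.python.learn.estimators import linear

    age = tf.contrib.layers.real_valued_column('age')  # placeholder feature column
    estimator = linear.LinearEstimator(
        feature_columns=[age],
        head=head_lib.poisson_regression_head())
    # estimator.fit(input_fn=input_fn, steps=100)  # input_fn supplied by the caller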

View File

@@ -1665,15 +1665,15 @@ class LinearEstimatorTest(test.TestCase):
            'feature', dimension=4)
    ]
    exp = experiment.Experiment(
-        estimator=linear._LinearEstimator(feature_columns=cont_features,
-                                          head=head_lib._regression_head()),
+        estimator=linear.LinearEstimator(feature_columns=cont_features,
+                                         head=head_lib.regression_head()),
        train_input_fn=test_data.iris_input_logistic_fn,
        eval_input_fn=test_data.iris_input_logistic_fn)
    exp.test()

  def testEstimatorContract(self):
    estimator_test_utils.assert_estimator_contract(self,
-                                                   linear._LinearEstimator)
+                                                   linear.LinearEstimator)

  def testLinearRegression(self):
    """Tests that loss goes down with training."""
@@ -1691,8 +1691,8 @@ class LinearEstimatorTest(test.TestCase):
        100)
    age = feature_column_lib.real_valued_column('age')

-    linear_estimator = linear._LinearEstimator(feature_columns=[age, language],
-                                               head=head_lib._regression_head())
+    linear_estimator = linear.LinearEstimator(feature_columns=[age, language],
+                                              head=head_lib.regression_head())
    linear_estimator.fit(input_fn=input_fn, steps=100)
    loss1 = linear_estimator.evaluate(input_fn=input_fn, steps=1)['loss']
    linear_estimator.fit(input_fn=input_fn, steps=400)
@@ -1717,9 +1717,9 @@ class LinearEstimatorTest(test.TestCase):
        100)
    age = feature_column_lib.real_valued_column('age')

-    linear_estimator = linear._LinearEstimator(
+    linear_estimator = linear.LinearEstimator(
        feature_columns=[age, language],
-        head=head_lib._poisson_regression_head())
+        head=head_lib.poisson_regression_head())
    linear_estimator.fit(input_fn=input_fn, steps=10)
    loss1 = linear_estimator.evaluate(input_fn=input_fn, steps=1)['loss']
    linear_estimator.fit(input_fn=input_fn, steps=100)
@@ -1736,8 +1736,8 @@ class LinearEstimatorTest(test.TestCase):
    sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
        example_id_column='example_id')
    with self.assertRaises(ValueError):
-      linear._LinearEstimator(
-          head=head_lib._regression_head(label_dimension=1),
+      linear.LinearEstimator(
+          head=head_lib.regression_head(label_dimension=1),
          feature_columns=[maintenance_cost, sq_footage],
          optimizer=sdca_optimizer,
          _joint_weights=True)

View File

@@ -139,7 +139,7 @@ class SVM(estimator.Estimator):
        model_dir=model_dir,
        config=config,
        params={
-            "head": head_lib._binary_svm_head(  # pylint: disable=protected-access
+            "head": head_lib.binary_svm_head(
                weight_column_name=weight_column_name,
                enable_centered_bias=False),
            "feature_columns": feature_columns,