diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py index 6016d93ec17..8289cea185a 100644 --- a/tensorflow/contrib/learn/__init__.py +++ b/tensorflow/contrib/learn/__init__.py @@ -28,13 +28,24 @@ See the @{$python/contrib.learn} guide. @@MetricSpec @@PredictionKey @@DNNClassifier +@@DNNEstimator @@DNNRegressor @@DNNLinearCombinedRegressor @@DNNLinearCombinedClassifier @@LinearClassifier +@@LinearEstimator @@LinearRegressor @@LogisticRegressor +@@Head +@@multi_class_head +@@multi_label_head +@@binary_svm_head +@@regression_head +@@poisson_regression_head +@@multi_head +@@no_op_train_fn + @@Experiment @@ExportStrategy @@TaskType diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index 8a5cfd321a7..8b385c93dc2 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -295,6 +295,7 @@ from __future__ import print_function from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError from tensorflow.contrib.learn.python.learn.estimators.constants import ProblemType from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier +from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNEstimator from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedClassifier from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedRegressor @@ -304,8 +305,17 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat +from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head +from tensorflow.contrib.learn.python.learn.estimators.head import Head +from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head +from tensorflow.contrib.learn.python.learn.estimators.head import multi_head +from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head +from tensorflow.contrib.learn.python.learn.estimators.head import no_op_train_fn +from tensorflow.contrib.learn.python.learn.estimators.head import poisson_regression_head +from tensorflow.contrib.learn.python.learn.estimators.head import regression_head from tensorflow.contrib.learn.python.learn.estimators.kmeans import KMeansClustering from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier +from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey diff --git a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py index d5de734d2bc..14750961efa 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py +++ 
b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py @@ -131,7 +131,7 @@ class ComposableModelTest(test.TestCase): language = feature_column.sparse_column_with_hash_bucket('language', 100) age = feature_column.real_valued_column('age') - head = head_lib._multi_class_head(n_classes=2) + head = head_lib.multi_class_head(n_classes=2) classifier = _linear_estimator(head, feature_columns=[age, language]) classifier.fit(input_fn=input_fn, steps=1000) @@ -157,7 +157,7 @@ class ComposableModelTest(test.TestCase): language = feature_column.sparse_column_with_hash_bucket('language', 100) age = feature_column.sparse_column_with_hash_bucket('age', 2) - head = head_lib._multi_class_head(n_classes=2) + head = head_lib.multi_class_head(n_classes=2) classifier = _joint_linear_estimator(head, feature_columns=[age, language]) classifier.fit(input_fn=input_fn, steps=1000) @@ -171,7 +171,7 @@ class ComposableModelTest(test.TestCase): """Tests multi-class classification using matrix data as input.""" cont_features = [feature_column.real_valued_column('feature', dimension=4)] - head = head_lib._multi_class_head(n_classes=3) + head = head_lib.multi_class_head(n_classes=3) classifier = _dnn_estimator( head, feature_columns=cont_features, hidden_units=[3, 3]) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index f9ba6711e69..86f780034e1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -304,7 +304,7 @@ class DNNClassifier(estimator.Estimator): config=config, params={ "head": - head_lib._multi_class_head( # pylint: disable=protected-access + head_lib.multi_class_head( n_classes, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias), @@ -579,7 +579,7 @@ class DNNRegressor(estimator.Estimator): config=config, params={ "head": - head_lib._regression_head( # pylint: disable=protected-access + head_lib.regression_head( label_dimension=label_dimension, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias), @@ -731,8 +731,7 @@ class DNNRegressor(estimator.Estimator): exports_to_keep=exports_to_keep) -# TODO(zakaria): Make it public when b/34751732 is fixed. -class _DNNEstimator(estimator.Estimator): +class DNNEstimator(estimator.Estimator): """An `Estimator` for TensorFlow DNN models with a user-specified `Head`. Example: @@ -745,20 +744,20 @@ class _DNNEstimator(estimator.Estimator): ...) sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b, ...)
- To create a _DNNEstimator for binary classification, where - estimator = _DNNEstimator( + To create a DNNEstimator for binary classification, where + estimator = DNNEstimator( feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], - head=head_lib._multi_class__head(n_classes=2), + head=tf.contrib.learn.multi_class_head(n_classes=2), hidden_units=[1024, 512, 256]) If your label is keyed with "y" in your labels dict, and weights are keyed with "w" in features dict, and you want to enable centered bias, - head = head_lib._multi_class__head( + head = tf.contrib.learn.multi_class_head( n_classes=2, label_name="y", weight_column_name="w", enable_centered_bias=True) - estimator = _DNNEstimator( + estimator = DNNEstimator( feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], head=head, hidden_units=[1024, 512, 256]) @@ -802,10 +801,10 @@ class _DNNEstimator(estimator.Estimator): feature_engineering_fn=None, embedding_lr_multipliers=None, input_layer_min_slice_size=None): - """Initializes a _DNNEstimator instance. + """Initializes a `DNNEstimator` instance. Args: - head: _Head instance. + head: `Head` instance. hidden_units: List of hidden units per layer. All layers are fully connected. Ex. `[64, 32]` means first layer has 64 nodes and second one has 32. @@ -836,9 +835,9 @@ class _DNNEstimator(estimator.Estimator): partitions. If not provided, will use the default of 64M. Returns: - A `_DNNEstimator` estimator. + A `DNNEstimator` estimator. """ - super(_DNNEstimator, self).__init__( + super(DNNEstimator, self).__init__( model_fn=_dnn_model_fn, model_dir=model_dir, config=config, @@ -854,4 +853,3 @@ class _DNNEstimator(estimator.Estimator): "input_layer_min_slice_size": input_layer_min_slice_size, }, feature_engineering_fn=feature_engineering_fn) - diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index e5da8c00828..ec66e2ad2f8 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -550,7 +550,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator): if not self._feature_columns: raise ValueError("Either linear_feature_columns or dnn_feature_columns " "must be defined.") - head = head_lib._multi_class_head( # pylint: disable=protected-access + head = head_lib.multi_class_head( n_classes=n_classes, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias) @@ -841,7 +841,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator): if not self._feature_columns: raise ValueError("Either linear_feature_columns or dnn_feature_columns " "must be defined.") - head = head_lib._regression_head( # pylint: disable=protected-access + head = head_lib.regression_head( weight_column_name=weight_column_name, label_dimension=label_dimension, enable_centered_bias=enable_centered_bias) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 380e75b693c..a8600c31947 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -60,7 +60,7 @@ def _assert_metrics_in_range(keys, metrics): metrics) -class _CheckCallsHead(head_lib._Head): # pylint: disable=protected-access +class _CheckCallsHead(head_lib.Head): """Head that checks
whether head_ops is called.""" def __init__(self): @@ -97,7 +97,7 @@ class EmbeddingMultiplierTest(test.TestCase): params = { 'dnn_feature_columns': [one_hot_language], - 'head': head_lib._multi_class_head(2), + 'head': head_lib.multi_class_head(2), 'dnn_hidden_units': [1], # Set lr mult to 0. to keep embeddings constant. 'embedding_lr_multipliers': { @@ -131,7 +131,7 @@ class EmbeddingMultiplierTest(test.TestCase): params = { 'dnn_feature_columns': [embedding_language, embedding_wire], - 'head': head_lib._multi_class_head(2), + 'head': head_lib.multi_class_head(2), 'dnn_hidden_units': [1], # Set lr mult to 0. to keep embeddings constant. 'embedding_lr_multipliers': { diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 465548fbf2b..c897b6c9d58 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -59,7 +59,7 @@ class EmbeddingMultiplierTest(test.TestCase): params = { 'feature_columns': [one_hot_language], - 'head': head_lib._multi_class_head(2), + 'head': head_lib.multi_class_head(2), 'hidden_units': [1], # Set lr mult to 0. to keep embeddings constant. 'embedding_lr_multipliers': { @@ -90,7 +90,7 @@ class EmbeddingMultiplierTest(test.TestCase): params = { 'feature_columns': [embedding_language, embedding_wire], - 'head': head_lib._multi_class_head(2), + 'head': head_lib.multi_class_head(2), 'hidden_units': [1], # Set lr mult to 0. to keep embeddings constant. 'embedding_lr_multipliers': { @@ -145,7 +145,7 @@ class DNNEstimatorTest(test.TestCase): exp.test() def testEstimatorContract(self): - estimator_test_utils.assert_estimator_contract(self, dnn._DNNEstimator) + estimator_test_utils.assert_estimator_contract(self, dnn.DNNEstimator) def testTrainWithWeights(self): """Tests training with given weight column.""" @@ -172,8 +172,8 @@ class DNNEstimatorTest(test.TestCase): } return features, labels - dnn_estimator = dnn._DNNEstimator( - head=head_lib._multi_class_head(2, weight_column_name='w'), + dnn_estimator = dnn.DNNEstimator( + head=head_lib.multi_class_head(2, weight_column_name='w'), feature_columns=[feature_column.real_valued_column('x')], hidden_units=[3, 3], config=run_config.RunConfig(tf_random_seed=1)) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 952cdeb5ec1..00cc23ce859 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -46,15 +46,141 @@ from tensorflow.python.ops import variables from tensorflow.python.summary import summary from tensorflow.python.training import training -# TODO(zakaria): add functions that creates a head and returns ModelOpFn + +class Head(object): + """Interface for the head/top of a model. + + Given logits (or output of a hidden layer), a Head knows how to compute + predictions, loss, default metric and export signature. It is meant to: + + 1) Simplify writing model_fn and make model_fn more configurable + 2) Support a wide range of machine learning models. Since most heads can work + with logits, they can support DNN, RNN, Wide, Wide&Deep, + Global objectives, Gradient boosted trees and many other types + of machine learning models.
+ 3) Allow users to seamlessly switch between 1 to n heads for multi + objective learning (see the _MultiHead implementation for more details) + + Common usage: + Here is a simplified model_fn to build a multiclass DNN model. + ```python + def _my_dnn_model_fn(features, labels, mode, params, config=None): + # Optionally your callers can pass head to model_fn as a param. + head = tf.contrib.learn.multi_class_head(...) + input = tf.contrib.layers.input_from_feature_columns(features, ...) + last_hidden_layer_out = tf.contrib.layers.stack( + input, tf.contrib.layers.fully_connected, [1000, 500]) + logits = tf.contrib.layers.fully_connected( + last_hidden_layer_out, head.logits_dimension, activation_fn=None) + + def _train_op_fn(loss): + return optimizer.minimize(loss) + + return head.create_model_fn_ops( + features=features, + labels=labels, + mode=mode, + train_op_fn=_train_op_fn, + logits=logits, + scope=...) + ``` + + Most heads also support logits_input, which is typically the output of the + last hidden layer. Some heads (like heads responsible for candidate sampling + or hierarchical softmax) intrinsically will not support logits, and you have + to pass logits_input instead. Here is a common usage, + ```python + return head.create_model_fn_ops( + features=features, + labels=labels, + mode=mode, + train_op_fn=_train_op_fn, + logits_input=last_hidden_layer_out, + scope=...) + ``` + + There are cases where computing and applying gradients cannot be meaningfully + captured by the train_op_fn this interface supports (for example, with a sync + optimizer). In such cases, you can take over that responsibility yourself. + Here is a common use case, + ```python + model_fn_ops = head.create_model_fn_ops( + features=features, + labels=labels, + mode=mode, + train_op_fn=tf.contrib.learn.no_op_train_fn, + logits=logits, + scope=...) + if mode == tf.contrib.learn.ModeKeys.TRAIN: + optimizer = ... + sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...) + update_op = tf.contrib.layers.optimize_loss(optimizer=sync, + loss=model_fn_ops.loss, ...) + hooks = [sync.make_session_run_hook(is_chief)] + ... update train_op and hooks in ModelFnOps and return + ``` + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractproperty + def logits_dimension(self): + """Size of the last dimension of the logits `Tensor`. + + Typically, logits is of shape `[batch_size, logits_dimension]`. + + Returns: + The expected size of the `logits` tensor. + """ + raise NotImplementedError("Calling an abstract method.") + + @abc.abstractmethod + def create_model_fn_ops(self, + features, + mode, + labels=None, + train_op_fn=None, + logits=None, + logits_input=None, + scope=None): + """Returns `ModelFnOps` that a model_fn can return. + + Please note that: + + Exactly one of `logits` and `logits_input` must be provided. + + All args must be passed via name. + + Args: + features: Input `dict` of `Tensor` objects. + mode: Estimator's `ModeKeys`. + labels: Labels `Tensor`, or `dict` of same. + train_op_fn: Function that takes a scalar loss `Tensor` and returns an op + to optimize the model with the loss. This is used in TRAIN mode and + must not be None. None is allowed in other modes. If you want to + optimize loss yourself you can pass `no_op_train_fn` and then use + ModelFnOps.loss to compute and apply gradients. + logits: logits `Tensor` to be used by the head. + logits_input: `Tensor` from which to build logits, often needed when you + don't want to compute the logits. Typically this is the activation of + the last hidden layer in a DNN.
Some heads (like the ones responsible for + candidate sampling) intrinsically avoid computing full logits and only + accept logits_input. + scope: Optional scope for `variable_scope`. + + Returns: + An instance of `ModelFnOps`. + + Raises: + ValueError: If `mode` is not recognized. + ValueError: If neither or both of `logits` and `logits_input` is provided. + """ + raise NotImplementedError("Calling an abstract method.") -def _regression_head(label_name=None, - weight_column_name=None, - label_dimension=1, - enable_centered_bias=False, - head_name=None): - """Creates a _Head for linear regression. +def regression_head(label_name=None, + weight_column_name=None, + label_dimension=1, + enable_centered_bias=False, + head_name=None): + """Creates a `Head` for linear regression. Args: label_name: String, name of the key in label dict. Can be null if label @@ -73,7 +199,7 @@ def _regression_head(label_name=None, will be `head_name`. Returns: - An instance of _Head + An instance of `Head` for linear regression. """ return _RegressionHead( label_name=label_name, @@ -85,12 +211,12 @@ def _regression_head(label_name=None, link_fn=array_ops.identity) -def _poisson_regression_head(label_name=None, - weight_column_name=None, - label_dimension=1, - enable_centered_bias=False, - head_name=None): - """Creates a _Head for linear regression. +def poisson_regression_head(label_name=None, + weight_column_name=None, + label_dimension=1, + enable_centered_bias=False, + head_name=None): + """Creates a `Head` for Poisson regression. Args: label_name: String, name of the key in label dict. Can be null if label @@ -109,7 +235,7 @@ def _poisson_regression_head(label_name=None, will be `head_name`. Returns: - An instance of _Head + An instance of `Head` for Poisson regression. """ return _RegressionHead( label_name=label_name, @@ -120,18 +246,18 @@ def _poisson_regression_head(label_name=None, loss_fn=_poisson_loss, link_fn=math_ops.exp) -# TODO(zakaria): Add logistic_regression_head +# TODO(zakaria): Consider adding a _RegressionHead for logistic_regression -def _multi_class_head(n_classes, - label_name=None, - weight_column_name=None, - enable_centered_bias=False, - head_name=None, - thresholds=None, - metric_class_ids=None, - loss_fn=None): - """Creates a _Head for multi class single label classification. +def multi_class_head(n_classes, + label_name=None, + weight_column_name=None, + enable_centered_bias=False, + head_name=None, + thresholds=None, + metric_class_ids=None, + loss_fn=None): + """Creates a `Head` for multi-class, single-label classification. The Head uses softmax cross entropy loss. @@ -157,7 +283,7 @@ def _multi_class_head(n_classes, optional. See `tf.losses` Returns: - An instance of _MultiClassHead. + An instance of `Head` for multi-class classification. Raises: ValueError: If `n_classes` is < 2, or `metric_class_ids` is provided when @@ -193,13 +319,13 @@ def _multi_class_head(n_classes, loss_fn=loss_fn) -def _binary_svm_head( +def binary_svm_head( label_name=None, weight_column_name=None, enable_centered_bias=False, head_name=None, thresholds=None,): - """Creates a `_Head` for binary classification with SVMs. + """Creates a `Head` for binary classification with SVMs. The head uses binary hinge loss. @@ -218,8 +344,7 @@ def _binary_svm_head( thresholds: thresholds for eval metrics, defaults to [.5] Returns: - An instance of `_Head`. - + An instance of `Head` for binary classification with SVM.
""" return _BinarySvmHead( label_name=label_name, @@ -229,15 +354,15 @@ def _binary_svm_head( thresholds=thresholds) -def _multi_label_head(n_classes, - label_name=None, - weight_column_name=None, - enable_centered_bias=False, - head_name=None, - thresholds=None, - metric_class_ids=None, - loss_fn=None): - """Creates a _Head for multi label classification. +def multi_label_head(n_classes, + label_name=None, + weight_column_name=None, + enable_centered_bias=False, + head_name=None, + thresholds=None, + metric_class_ids=None, + loss_fn=None): + """Creates a Head for multi label classification. The Head uses sigmoid cross entropy loss. @@ -262,7 +387,7 @@ def _multi_label_head(n_classes, optional. See `tf.losses` Returns: - An instance of _MultiLabelHead. + An instance of `Head` for multi label classification. Raises: ValueError: If n_classes is < 2 @@ -284,16 +409,16 @@ def _multi_label_head(n_classes, loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) -def _multi_head(heads, loss_weights=None): +def multi_head(heads, loss_weights=None): """Creates a MultiHead stemming from same logits/hidden layer. Args: - heads: list of _Head objects. - loss_weights: optional list of weights to be used to combine losses from + heads: list of Head objects. + loss_weights: optional list of weights to be used to merge losses from each head. All losses are weighted equally if not provided. Returns: - A _Head instance that combines multiple heads. + A instance of `Head` that merges multiple heads. Raises: ValueError: if heads and loss_weights have different size. @@ -302,7 +427,7 @@ def _multi_head(heads, loss_weights=None): if len(loss_weights) != len(heads): raise ValueError("heads and loss_weights must have same size") - def _weighted_loss_combiner(losses): + def _weighted_loss_merger(losses): if loss_weights: if len(losses) != len(loss_weights): raise ValueError("losses and loss_weights must have same size") @@ -313,7 +438,7 @@ def _multi_head(heads, loss_weights=None): else: return math_ops.add_n(losses) - return _MultiHead(heads, loss_combiner=_weighted_loss_combiner) + return _MultiHead(heads, loss_merger=_weighted_loss_merger) def no_op_train_fn(loss): @@ -321,64 +446,7 @@ def no_op_train_fn(loss): return control_flow_ops.no_op() -# TODO(zakaria): Make the classes public once we are ready for users to subclass -# them. See b/34751732 -class _Head(object): - """Interface for the head/top of a model. - - Given logits or output of a hidden layer, a Head knows how to compute - predictions, loss, default metric and export signature. - """ - __metaclass__ = abc.ABCMeta - - @abc.abstractproperty - def logits_dimension(self): - """Size of the last dimension of the logits `Tensor`. - - Typically, logits is of shape `[batch_size, logits_dimension]`. - - Returns: - Number of logits values per example. - """ - raise NotImplementedError("Calling an abstract method.") - - @abc.abstractmethod - def create_model_fn_ops(self, - features, - mode, - labels=None, - train_op_fn=None, - logits=None, - logits_input=None, - scope=None): - """Returns ops for a model_fn. - - Exactly one of `logits` and `logits_input` must be provided. - - All args must be passed via name. - - Args: - features: Input `dict` of `Tensor` objects. - mode: Estimator's `ModeKeys`. - labels: Labels `Tensor`, or `dict` of same. - train_op_fn: Function that takes a scalar loss and returns an op to - optimize with the loss. Must not be `None` in TRAIN mode. If you want - to optimize loss yourself you can pass `no_op_train_fn`. 
- logits: logits `Tensor`, or `dict` of same, to be used for the head. - logits_input: `Tensor` from which to build logits. - scope: Optional scope for `variable_scope`. - - Returns: - `ModelFnOps`. - - Raises: - ValueError: if `mode` is not recognized, or neither or both of `logits` - and `logits_input` is provided. - """ - raise NotImplementedError("Calling an abstract method.") - - -class _SingleHead(_Head): +class _SingleHead(Head): """Interface for a single head/top of a model.""" __metaclass__ = abc.ABCMeta @@ -565,7 +633,7 @@ def _create_model_fn_ops(features, class _RegressionHead(_SingleHead): - """_Head for regression with a generalized linear model.""" + """`Head` for regression with a generalized linear model.""" def __init__(self, label_dimension, @@ -575,7 +643,7 @@ class _RegressionHead(_SingleHead): weight_column_name=None, enable_centered_bias=False, head_name=None): - """Head for regression. + """`Head` for regression. Args: label_dimension: Number of regression labels per example. This is the @@ -614,7 +682,7 @@ class _RegressionHead(_SingleHead): logits=None, logits_input=None, scope=None): - """See `_Head`.""" + """See `Head`.""" return _create_model_fn_ops( features=features, mode=mode, @@ -682,7 +750,7 @@ def _one_class_to_two_class_logits(logits): class _BinaryLogisticHead(_SingleHead): - """_Head for binary logistic classifciation.""" + """`Head` for binary classification with logistic regression.""" def __init__(self, label_name=None, @@ -691,7 +759,7 @@ class _BinaryLogisticHead(_SingleHead): head_name=None, loss_fn=None, thresholds=None): - """Base type for all single heads. + """`Head` for binary classification with logistic regression. Args: label_name: String, name of the key in label dict. Can be `None` if label @@ -729,7 +797,7 @@ class _BinaryLogisticHead(_SingleHead): logits=None, logits_input=None, scope=None): - """See `_Head`.""" + """See `Head`.""" return _create_model_fn_ops( features=features, mode=mode, @@ -844,7 +912,7 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None): class _MultiClassHead(_SingleHead): - """_Head for classification.""" + """`Head` for multi-class classification.""" def __init__(self, n_classes, @@ -855,7 +923,7 @@ class _MultiClassHead(_SingleHead): loss_fn=None, thresholds=None, metric_class_ids=None): - """_Head for classification. + """`Head` for multi-class classification. Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use @@ -905,7 +973,7 @@ class _MultiClassHead(_SingleHead): logits=None, logits_input=None, scope=None): - """See `_Head`.""" + """See `Head`.""" return _create_model_fn_ops( features=features, mode=mode, @@ -1039,7 +1107,7 @@ def _assert_labels_rank(labels): class _BinarySvmHead(_SingleHead): - """_Head for binary classification using SVMs.""" + """`Head` for binary classification using SVM.""" def __init__(self, label_name, weight_column_name, enable_centered_bias, head_name, thresholds): @@ -1069,7 +1137,7 @@ class _BinarySvmHead(_SingleHead): logits=None, logits_input=None, scope=None): - """See `_Head`.""" + """See `Head`.""" return _create_model_fn_ops( features=features, mode=mode, @@ -1125,7 +1193,7 @@ class _BinarySvmHead(_SingleHead): class _MultiLabelHead(_SingleHead): - """_Head for multlabel classification.""" + """`Head` for multi-label classification.""" # TODO(zakaria): add signature and metric for multilabel.
def __init__(self, @@ -1162,7 +1230,7 @@ class _MultiLabelHead(_SingleHead): logits=None, logits_input=None, scope=None): - """See `_Head`.""" + """See `Head`.""" return _create_model_fn_ops( features=features, mode=mode, @@ -1240,24 +1308,52 @@ class _MultiLabelHead(_SingleHead): return metrics -class _MultiHead(_Head): - """_Head to combine multiple _Head objects. + +class _MultiHead(Head): + """`Head` implementation for multi-objective learning. + + This class is responsible for using and merging the output of multiple + `Head` objects. All heads stem from the same logits/logits_input tensor. - For training, combines losses of each heads according a function provided by - user. - For eval, adds a /head_name suffix to the keys in eval metrics. - For inference, updates keys prediction dict to a 2-tuple, - (head_name, prediction_key) + Common usage: + For simple use cases, you can pass the activation of the last hidden layer + from your model_fn like this, + ```python + last_hidden_layer_activation = ... Build your model. + multi_head = ... + return multi_head.create_model_fn_ops( + ..., logits_input=last_hidden_layer_activation, ...) + ``` + + Or you can create a logits tensor of shape + [batch_size, multi_head.logits_dimension]. _MultiHead will split the + logits for you. + ```python + return multi_head.create_model_fn_ops(..., logits=logits, ...) + ``` + + For more complex use cases like a multi-task/multi-tower model, or when the + logits for each head have to be created separately, you can pass a dict of + logits where the keys match the names of the single heads. + ```python + logits = {"head1": logits1, "head2": logits2} + return multi_head.create_model_fn_ops(..., logits=logits, ...) + ``` + + Here is what this class does: + + For training, merges the losses of the individual heads according to a + function provided by the user, then calls the user-provided train_op_fn + with this merged loss. + + For eval, merges metrics by adding a head_name suffix to the keys in the + eval metrics. + + For inference, updates the keys in the prediction dict to 2-tuples of the + form (head_name, prediction_key). """ - def __init__(self, heads, loss_combiner): - """_Head to combine multiple _Head objects. + def __init__(self, heads, loss_merger): + """`Head` that merges multiple `Head` objects. Args: heads: list of _Head objects. - loss_combiner: function that takes a list of loss tensors for the heads + loss_merger: function that takes a list of loss tensors for the heads and returns the final loss tensor for the multi head. Raises: @@ -1274,7 +1370,7 @@ class _MultiHead(_Head): self._logits_dimension += head.logits_dimension self._heads = heads - self._loss_combiner = loss_combiner + self._loss_merger = loss_merger @property def logits_dimension(self): @@ -1353,11 +1449,11 @@ class _MultiHead(_Head): if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError("train_op_fn can not be None in TRAIN mode.") - return self._combine_train(all_model_fn_ops, train_op_fn) + return self._merge_train(all_model_fn_ops, train_op_fn) if mode == model_fn.ModeKeys.INFER: - return self._combine_infer(all_model_fn_ops) + return self._merge_infer(all_model_fn_ops) if mode == model_fn.ModeKeys.EVAL: - return self._combine_eval(all_model_fn_ops) + return self._merge_eval(all_model_fn_ops) raise ValueError("mode=%s unrecognized" % str(mode)) def _split_logits(self, logits): @@ -1379,8 +1475,8 @@ class _MultiHead(_Head): begin += current_logits_size return all_logits - def _combine_train(self, all_model_fn_ops, train_op_fn): - """Combines list of ModelFnOps for training.
+ def _merge_train(self, all_model_fn_ops, train_op_fn): + """Merges the list of ModelFnOps for training. Args: all_model_fn_ops: list of ModelFnOps for the individual heads. documentation for more details. Returns: - ModelFnOps that combines all the heads. + ModelFnOps that merges all heads for TRAIN. """ losses = [] additional_train_ops = [] for m in all_model_fn_ops: losses.append(m.loss) additional_train_ops.append(m.train_op) - loss = self._loss_combiner(losses) + loss = self._loss_merger(losses) train_op = train_op_fn(loss) train_op = control_flow_ops.group(train_op, *additional_train_ops) @@ -1404,14 +1500,14 @@ class _MultiHead(_Head): loss=loss, train_op=train_op) - def _combine_infer(self, all_model_fn_ops): - """Combines list of ModelFnOps for inference. + def _merge_infer(self, all_model_fn_ops): + """Merges the list of ModelFnOps for inference. Args: all_model_fn_ops: list of ModelFnOps for the individual heads. Returns: - ModelFnOps that combines all the heads. + ModelFnOps that merges all the heads for INFER. """ predictions = {} output_alternatives = {} @@ -1426,14 +1522,14 @@ class _MultiHead(_Head): predictions=predictions, output_alternatives=output_alternatives) - def _combine_eval(self, all_model_fn_ops): - """Combines list of ModelFnOps for eval. + def _merge_eval(self, all_model_fn_ops): + """Merges the list of ModelFnOps for eval. Args: all_model_fn_ops: list of ModelFnOps for the individual heads. Returns: - ModelFnOps that combines all the heads. + ModelFnOps that merges all the heads for EVAL. """ predictions = {} metrics = {} @@ -1446,7 +1542,7 @@ class _MultiHead(_Head): for k, v in m.eval_metric_ops.items(): # metrics["%s/%s" % (k, head_name)] = v metrics[k] = v - loss = self._loss_combiner(losses) + loss = self._loss_merger(losses) return model_fn.ModelFnOps( mode=model_fn.ModeKeys.EVAL, @@ -1733,3 +1829,14 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold): predictions, labels=labels, thresholds=(threshold,), weights=_float_weights_or_none(weights)) return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) + + +# Aliases +# TODO(zakaria): Remove these aliases, see b/34751732 +_regression_head = regression_head +_poisson_regression_head = poisson_regression_head +_multi_class_head = multi_class_head +_binary_svm_head = binary_svm_head +_multi_label_head = multi_label_head +_multi_head = multi_head +_Head = Head diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index faa3108caf3..cc050ba5362 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -112,7 +112,7 @@ class PoissonHeadTest(test.TestCase): return sum(lpl)/len(lpl) def testPoissonWithLogits(self): - head = head_lib._poisson_regression_head() + head = head_lib.poisson_regression_head() labels = ((0.,), (1.,), (1.,)) logits = ((0.,), (-1.,), (3.,)) with ops.Graph().as_default(), session.Session(): @@ -140,7 +140,7 @@ class RegressionHeadTest(test.TestCase): # TODO(zakaria): test multilabel regression. def testRegressionWithLogits(self): - head = head_lib._regression_head() + head = head_lib.regression_head() with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( {}, @@ -154,7 +154,7 @@ class RegressionHeadTest(test.TestCase): _assert_metrics(self, 5. / 3, {"loss": 5.
/ 3}, model_fn_ops) def testRegressionWithInvalidLogits(self): - head = head_lib._regression_head() + head = head_lib.regression_head() with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): head.create_model_fn_ops( @@ -165,7 +165,7 @@ class RegressionHeadTest(test.TestCase): logits=((1., 1.), (1., 1.), (3., 1.))) def testRegressionWithLogitsInput(self): - head = head_lib._regression_head() + head = head_lib.regression_head() with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( {}, @@ -183,7 +183,7 @@ class RegressionHeadTest(test.TestCase): _assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops) def testRegressionWithLogitsAndLogitsInput(self): - head = head_lib._regression_head() + head = head_lib.regression_head() with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp( ValueError, "Both logits and logits_input supplied"): @@ -196,7 +196,7 @@ class RegressionHeadTest(test.TestCase): logits=((1.,), (1.,), (3.,))) def testRegressionEvalMode(self): - head = head_lib._regression_head() + head = head_lib.regression_head() with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( {}, @@ -212,7 +212,7 @@ class RegressionHeadTest(test.TestCase): def testRegressionWithLabelName(self): label_name = "my_label" - head = head_lib._regression_head(label_name=label_name) + head = head_lib.regression_head(label_name=label_name) with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( {}, @@ -226,7 +226,7 @@ class RegressionHeadTest(test.TestCase): _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) def testRegressionWithWeights(self): - head = head_lib._regression_head(weight_column_name="label_weight") + head = head_lib.regression_head(weight_column_name="label_weight") with ops.Graph().as_default(), session.Session(): weights = ((2.,), (5.,), (0.,)) model_fn_ops = head.create_model_fn_ops( @@ -242,7 +242,7 @@ class RegressionHeadTest(test.TestCase): model_fn_ops) def testRegressionWithCenteredBias(self): - head = head_lib._regression_head(enable_centered_bias=True) + head = head_lib.regression_head(enable_centered_bias=True) with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( {}, @@ -264,7 +264,7 @@ class RegressionHeadTest(test.TestCase): _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops) def testRegressionErrorInSparseTensorLabels(self): - head = head_lib._regression_head() + head = head_lib.regression_head() with ops.Graph().as_default(): labels = sparse_tensor.SparseTensorValue( indices=((0, 0), (1, 0), (2, 0)), @@ -317,7 +317,7 @@ class MultiLabelHeadTest(test.TestCase): def testMultiLabelWithLogits(self): n_classes = 3 - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( @@ -334,7 +334,7 @@ class MultiLabelHeadTest(test.TestCase): n_classes = 2 labels = ((0, 1),) logits = ((1., 0.),) - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( @@ -361,7 +361,7 @@ class MultiLabelHeadTest(test.TestCase): }, model_fn_ops) def testMultiLabelWithInvalidLogits(self): - head = head_lib._multi_label_head(n_classes=len(self._labels[0]) + 1) + head = head_lib.multi_label_head(n_classes=len(self._labels[0]) + 1) with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): head.create_model_fn_ops( @@ -370,7 +370,7 @@ class MultiLabelHeadTest(test.TestCase): def testMultiLabelWithLogitsInput(self): n_classes = 3 - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( @@ -407,7 +407,7 @@ class MultiLabelHeadTest(test.TestCase): def testMultiLabelWithLogitsAndLogitsInput(self): n_classes = 3 - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp( @@ -418,7 +418,7 @@ class MultiLabelHeadTest(test.TestCase): def testMultiLabelEvalMode(self): n_classes = 3 - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( @@ -434,7 +434,7 @@ class MultiLabelHeadTest(test.TestCase): def testMultiClassEvalModeWithLargeLogits(self): n_classes = 3 - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) logits = ((2., 0., -1),) with ops.Graph().as_default(), session.Session(): @@ -474,7 +474,7 @@ class MultiLabelHeadTest(test.TestCase): def testMultiLabelWithLabelName(self): n_classes = 3 label_name = "my_label" - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, label_name=label_name, metric_class_ids=range(n_classes)) @@ -491,7 +491,7 @@ class MultiLabelHeadTest(test.TestCase): def testMultiLabelWithWeight(self): n_classes = 3 - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, weight_column_name="label_weight", metric_class_ids=range(n_classes)) @@ -510,7 +510,7 @@ class MultiLabelHeadTest(test.TestCase): def testMultiLabelWithCustomLoss(self): n_classes = 3 - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, weight_column_name="label_weight", metric_class_ids=range(n_classes), @@ -530,7 +530,7 @@ class 
MultiLabelHeadTest(test.TestCase): def testMultiLabelWithCenteredBias(self): n_classes = 3 - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, enable_centered_bias=True, metric_class_ids=range(n_classes)) @@ -559,7 +559,7 @@ class MultiLabelHeadTest(test.TestCase): def testMultiLabelSparseTensorLabels(self): n_classes = 3 - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) with ops.Graph().as_default(), session.Session(): labels = sparse_tensor.SparseTensorValue( @@ -580,7 +580,7 @@ class MultiLabelHeadTest(test.TestCase): def testMultiLabelSparseTensorLabelsTooFewClasses(self): n_classes = 3 - head = head_lib._multi_label_head( + head = head_lib.multi_label_head( n_classes=n_classes, metric_class_ids=range(n_classes)) # Set _logits_dimension (n_classes) to a lower value; if it's set to 1 # upfront, the class throws an error during initialization. @@ -629,7 +629,7 @@ class BinaryClassificationHeadTest(test.TestCase): def testBinaryClassificationWithLogits(self): n_classes = 2 - head = head_lib._multi_class_head(n_classes=n_classes) + head = head_lib.multi_class_head(n_classes=n_classes) with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) @@ -644,7 +644,7 @@ class BinaryClassificationHeadTest(test.TestCase): self._expected_eval_metrics(expected_loss), model_fn_ops) def testBinaryClassificationWithInvalidLogits(self): - head = head_lib._multi_class_head(n_classes=len(self._labels) + 1) + head = head_lib.multi_class_head(n_classes=len(self._labels) + 1) with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): head.create_model_fn_ops( @@ -653,7 +653,7 @@ class BinaryClassificationHeadTest(test.TestCase): def testBinaryClassificationWithLogitsInput(self): n_classes = 2 - head = head_lib._multi_class_head(n_classes=n_classes) + head = head_lib.multi_class_head(n_classes=n_classes) with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) @@ -682,7 +682,7 @@ class BinaryClassificationHeadTest(test.TestCase): }, model_fn_ops) def testBinaryClassificationWithLogitsAndLogitsInput(self): - head = head_lib._multi_class_head(n_classes=2) + head = head_lib.multi_class_head(n_classes=2) with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp( ValueError, "Both logits and logits_input supplied"): @@ -692,7 +692,7 @@ class BinaryClassificationHeadTest(test.TestCase): def testBinaryClassificationEvalMode(self): n_classes = 2 - head = head_lib._multi_class_head(n_classes=n_classes) + head = head_lib.multi_class_head(n_classes=n_classes) with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) @@ -709,7 +709,7 @@ class BinaryClassificationHeadTest(test.TestCase): def testBinaryClassificationInferMode(self): n_classes = 2 - head = head_lib._multi_class_head(n_classes=n_classes) + head = head_lib.multi_class_head(n_classes=n_classes) with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) @@ -722,8 +722,8 @@ class BinaryClassificationHeadTest(test.TestCase): def testBinaryClassificationInferMode_withWightColumn(self): n_classes = 2 - head = head_lib._multi_class_head(n_classes=n_classes, - 
weight_column_name="label_weight") + head = head_lib.multi_class_head(n_classes=n_classes, + weight_column_name="label_weight") with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) @@ -738,7 +738,7 @@ class BinaryClassificationHeadTest(test.TestCase): def testErrorInSparseTensorLabels(self): n_classes = 2 - head = head_lib._multi_class_head(n_classes=n_classes) + head = head_lib.multi_class_head(n_classes=n_classes) with ops.Graph().as_default(): labels = sparse_tensor.SparseTensorValue( indices=((0, 0), (1, 0), (2, 0)), @@ -755,7 +755,7 @@ class BinaryClassificationHeadTest(test.TestCase): def testBinaryClassificationWithLabelName(self): label_name = "my_label" - head = head_lib._multi_class_head(n_classes=2, label_name=label_name) + head = head_lib.multi_class_head(n_classes=2, label_name=label_name) with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) @@ -774,7 +774,7 @@ class BinaryClassificationHeadTest(test.TestCase): def testBinaryClassificationWithWeights(self): n_classes = 2 - head = head_lib._multi_class_head( + head = head_lib.multi_class_head( n_classes=n_classes, weight_column_name="label_weight") with ops.Graph().as_default(), session.Session(): weights = ((1.,), (0.,)) @@ -808,7 +808,7 @@ class BinaryClassificationHeadTest(test.TestCase): model_fn_ops) def testBinaryClassificationWithCustomLoss(self): - head = head_lib._multi_class_head( + head = head_lib.multi_class_head( n_classes=2, weight_column_name="label_weight", loss_fn=_sigmoid_cross_entropy) with ops.Graph().as_default(), session.Session(): @@ -844,7 +844,7 @@ class BinaryClassificationHeadTest(test.TestCase): model_fn_ops) def testBinaryClassificationWithCenteredBias(self): - head = head_lib._multi_class_head(n_classes=2, enable_centered_bias=True) + head = head_lib.multi_class_head(n_classes=2, enable_centered_bias=True) with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) @@ -904,7 +904,7 @@ class MultiClassHeadTest(test.TestCase): def testMultiClassWithLogits(self): n_classes = 3 - head = head_lib._multi_class_head( + head = head_lib.multi_class_head( n_classes=n_classes, metric_class_ids=range(n_classes)) with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit @@ -920,7 +920,7 @@ class MultiClassHeadTest(test.TestCase): self._expected_eval_metrics(expected_loss), model_fn_ops) def testMultiClassWithInvalidLogits(self): - head = head_lib._multi_class_head(n_classes=len(self._logits[0]) + 1) + head = head_lib.multi_class_head(n_classes=len(self._logits[0]) + 1) with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): head.create_model_fn_ops( @@ -928,7 +928,7 @@ class MultiClassHeadTest(test.TestCase): logits=self._logits) def testMultiClassWithNoneTrainOpFnInTrain(self): - head = head_lib._multi_class_head(n_classes=3) + head = head_lib.multi_class_head(n_classes=3) with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp( ValueError, "train_op_fn can not be None in TRAIN mode"): @@ -939,7 +939,7 @@ class MultiClassHeadTest(test.TestCase): def testMultiClassWithLogitsInput(self): n_classes = 3 - head = head_lib._multi_class_head( + head = head_lib.multi_class_head( n_classes=n_classes, metric_class_ids=range(n_classes)) with ops.Graph().as_default(), 
session.Session(): # logloss: z:label, x:logit @@ -978,7 +978,7 @@ class MultiClassHeadTest(test.TestCase): def testMultiClassWithLogitsAndLogitsInput(self): n_classes = 3 - head = head_lib._multi_class_head( + head = head_lib.multi_class_head( n_classes=n_classes, metric_class_ids=range(n_classes)) with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp( @@ -989,7 +989,7 @@ class MultiClassHeadTest(test.TestCase): def testMultiClassEvalMode(self): n_classes = 3 - head = head_lib._multi_class_head( + head = head_lib.multi_class_head( n_classes=n_classes, metric_class_ids=range(n_classes)) with ops.Graph().as_default(), session.Session(): # logloss: z:label, x:logit @@ -1007,7 +1007,7 @@ class MultiClassHeadTest(test.TestCase): def testMultiClassEvalModeWithLargeLogits(self): n_classes = 3 - head = head_lib._multi_class_head( + head = head_lib.multi_class_head( n_classes=n_classes, metric_class_ids=range(n_classes)) logits = ((2., 0., -1),) with ops.Graph().as_default(), session.Session(): @@ -1046,7 +1046,7 @@ class MultiClassHeadTest(test.TestCase): def testMultiClassWithWeight(self): n_classes = 3 - head = head_lib._multi_class_head( + head = head_lib.multi_class_head( n_classes=n_classes, weight_column_name="label_weight", metric_class_ids=range(n_classes)) @@ -1069,7 +1069,7 @@ class MultiClassHeadTest(test.TestCase): def testMultiClassWithCustomLoss(self): n_classes = 3 - head = head_lib._multi_class_head( + head = head_lib.multi_class_head( n_classes=n_classes, weight_column_name="label_weight", metric_class_ids=range(n_classes), @@ -1094,7 +1094,7 @@ class MultiClassHeadTest(test.TestCase): def testInvalidNClasses(self): for n_classes in (None, -1, 0, 1): with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"): - head_lib._multi_class_head(n_classes=n_classes) + head_lib.multi_class_head(n_classes=n_classes) class BinarySvmHeadTest(test.TestCase): @@ -1116,7 +1116,7 @@ class BinarySvmHeadTest(test.TestCase): self._expected_losses = (.5, 0.) 
def testBinarySVMWithLogits(self): - head = head_lib._binary_svm_head() + head = head_lib.binary_svm_head() with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( {}, @@ -1134,7 +1134,7 @@ class BinarySvmHeadTest(test.TestCase): }, model_fn_ops) def testBinarySVMWithInvalidLogits(self): - head = head_lib._binary_svm_head() + head = head_lib.binary_svm_head() with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): head.create_model_fn_ops( @@ -1142,7 +1142,7 @@ class BinarySvmHeadTest(test.TestCase): logits=np.ones((2, 2))) def testBinarySVMWithLogitsInput(self): - head = head_lib._binary_svm_head() + head = head_lib.binary_svm_head() with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( {}, @@ -1164,7 +1164,7 @@ class BinarySvmHeadTest(test.TestCase): }, model_fn_ops) def testBinarySVMWithLogitsAndLogitsInput(self): - head = head_lib._binary_svm_head() + head = head_lib.binary_svm_head() with ops.Graph().as_default(), session.Session(): with self.assertRaisesRegexp( ValueError, "Both logits and logits_input supplied"): @@ -1177,7 +1177,7 @@ class BinarySvmHeadTest(test.TestCase): logits=self._predictions) def testBinarySVMEvalMode(self): - head = head_lib._binary_svm_head() + head = head_lib.binary_svm_head() with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( {}, @@ -1197,7 +1197,7 @@ class BinarySvmHeadTest(test.TestCase): def testBinarySVMWithLabelName(self): label_name = "my_label" - head = head_lib._binary_svm_head(label_name=label_name) + head = head_lib.binary_svm_head(label_name=label_name) with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( {}, @@ -1215,7 +1215,7 @@ class BinarySvmHeadTest(test.TestCase): }, model_fn_ops) def testBinarySVMWithWeights(self): - head = head_lib._binary_svm_head(weight_column_name="weights") + head = head_lib.binary_svm_head(weight_column_name="weights") with ops.Graph().as_default(), session.Session(): weights = (7., 11.) 
model_fn_ops = head.create_model_fn_ops( @@ -1235,7 +1235,7 @@ class BinarySvmHeadTest(test.TestCase): }, model_fn_ops) def testBinarySVMWithCenteredBias(self): - head = head_lib._binary_svm_head(enable_centered_bias=True) + head = head_lib.binary_svm_head(enable_centered_bias=True) with ops.Graph().as_default(), session.Session(): model_fn_ops = head.create_model_fn_ops( {}, @@ -1265,21 +1265,21 @@ class BinarySvmHeadTest(test.TestCase): class MultiHeadTest(test.TestCase): def testInvalidHeads(self): - named_head = head_lib._multi_class_head( + named_head = head_lib.multi_class_head( n_classes=3, label_name="label", head_name="head1") - unnamed_head = head_lib._multi_class_head( + unnamed_head = head_lib.multi_class_head( n_classes=4, label_name="label") with self.assertRaisesRegexp(ValueError, "must have names"): - head_lib._multi_head((named_head, unnamed_head)) + head_lib.multi_head((named_head, unnamed_head)) with self.assertRaisesRegexp(ValueError, "must be SingleHead"): - head_lib._multi_head((named_head, head_lib._multi_head((named_head,)))) + head_lib.multi_head((named_head, head_lib.multi_head((named_head,)))) def testTrainWithNoneTrainOpFn(self): - head1 = head_lib._multi_class_head( + head1 = head_lib.multi_class_head( n_classes=3, label_name="label1", head_name="head1") - head2 = head_lib._multi_class_head( + head2 = head_lib.multi_class_head( n_classes=4, label_name="label2", head_name="head2") - head = head_lib._multi_head((head1, head2)) + head = head_lib.multi_head((head1, head2)) labels = { "label1": (1,), "label2": (1,) @@ -1294,11 +1294,11 @@ class MultiHeadTest(test.TestCase): logits=((-0.7, 0.2, .1, .1, .1, .1, .1),)) def testTrain_withNoHeadWeights(self): - head1 = head_lib._multi_class_head( + head1 = head_lib.multi_class_head( n_classes=3, label_name="label1", head_name="head1") - head2 = head_lib._multi_class_head( + head2 = head_lib.multi_class_head( n_classes=4, label_name="label2", head_name="head2") - head = head_lib._multi_head((head1, head2)) + head = head_lib.multi_head((head1, head2)) labels = { "label1": (1,), "label2": (1,) @@ -1320,11 +1320,11 @@ class MultiHeadTest(test.TestCase): self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3) def testTrain_withHeadWeights(self): - head1 = head_lib._multi_class_head( + head1 = head_lib.multi_class_head( n_classes=3, label_name="label1", head_name="head1") - head2 = head_lib._multi_class_head( + head2 = head_lib.multi_class_head( n_classes=4, label_name="label2", head_name="head2") - head = head_lib._multi_head((head1, head2), (1, .5)) + head = head_lib.multi_head((head1, head2), (1, .5)) labels = { "label1": (1,), "label2": (1,) @@ -1345,11 +1345,11 @@ class MultiHeadTest(test.TestCase): self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3) def testTrain_withDictLogits(self): - head1 = head_lib._multi_class_head( + head1 = head_lib.multi_class_head( n_classes=3, label_name="label1", head_name="head1") - head2 = head_lib._multi_class_head( + head2 = head_lib.multi_class_head( n_classes=4, label_name="label2", head_name="head2") - head = head_lib._multi_head((head1, head2)) + head = head_lib.multi_head((head1, head2)) labels = { "label1": (1,), "label2": (1,) @@ -1372,11 +1372,11 @@ class MultiHeadTest(test.TestCase): self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3) def testInfer(self): - head1 = head_lib._multi_class_head( + head1 = head_lib.multi_class_head( n_classes=3, label_name="label1", head_name="head1") - head2 = head_lib._multi_class_head( + head2 = 
head_lib.multi_class_head( n_classes=4, label_name="label2", head_name="head2") - head = head_lib._multi_head((head1, head2), (1, .5)) + head = head_lib.multi_head((head1, head2), (1, .5)) labels = { "label1": (1,), "label2": (1,) @@ -1422,11 +1422,11 @@ class MultiHeadTest(test.TestCase): ), model_fn_ops.output_alternatives["head2"][1].keys()) def testEval(self): - head1 = head_lib._multi_class_head( + head1 = head_lib.multi_class_head( n_classes=3, label_name="label1", head_name="head1") - head2 = head_lib._multi_class_head( + head2 = head_lib.multi_class_head( n_classes=4, label_name="label2", head_name="head2") - head = head_lib._multi_head((head1, head2), (1, .5)) + head = head_lib.multi_head((head1, head2), (1, .5)) labels = { "label1": (1,), "label2": (1,) diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 30e78117a7e..faf78a36752 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -419,7 +419,7 @@ class LinearClassifier(estimator.Estimator): enable_centered_bias = False logging.warning("centered_bias is not supported with SDCA, " "please disable it explicitly.") - head = head_lib._multi_class_head( # pylint: disable=protected-access + head = head_lib.multi_class_head( n_classes, weight_column_name=weight_column_name, enable_centered_bias=enable_centered_bias) @@ -686,7 +686,7 @@ class LinearRegressor(estimator.Estimator): enable_centered_bias = False logging.warning("centered_bias is not supported with SDCA, " "please disable it explicitly.") - head = head_lib._regression_head( # pylint: disable=protected-access + head = head_lib.regression_head( weight_column_name=weight_column_name, label_dimension=label_dimension, enable_centered_bias=enable_centered_bias) @@ -824,8 +824,7 @@ class LinearRegressor(estimator.Estimator): exports_to_keep=exports_to_keep) -# TODO(zakaria): Make it public when b/34751732 is fixed. -class _LinearEstimator(estimator.Estimator): +class LinearEstimator(estimator.Estimator): """Linear model with user specified head. Train a generalized linear model to predict label value given observation of @@ -840,9 +839,9 @@ class _LinearEstimator(estimator.Estimator): sparse_feature_a_x_sparse_feature_b = crossed_column(...) - estimator = _LinearEstimator( + estimator = LinearEstimator( feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b], - head=head_lib._poisson_regression_head()) + head=head_lib.poisson_regression_head()) # Input builders def input_fn_train: # returns x, y @@ -879,7 +878,7 @@ class _LinearEstimator(estimator.Estimator): _joint_weights=False, config=None, feature_engineering_fn=None): - """Construct a `_LinearEstimator` object. + """Construct a `LinearEstimator` object. Args: feature_columns: An iterable containing all the feature columns used by @@ -907,14 +906,14 @@ class _LinearEstimator(estimator.Estimator): into the model. Returns: - A `_LinearEstimator` estimator. + A `LinearEstimator` estimator. 
Raises: ValueError: if optimizer is not supported, e.g., SDCAOptimizer """ assert feature_columns if isinstance(optimizer, sdca_optimizer.SDCAOptimizer): - raise ValueError("_LinearEstimator does not support SDCA optimizer.") + raise ValueError("LinearEstimator does not support SDCA optimizer.") params = { "head": head, @@ -923,7 +922,7 @@ class _LinearEstimator(estimator.Estimator): "gradient_clip_norm": gradient_clip_norm, "joint_weights": _joint_weights, } - super(_LinearEstimator, self).__init__( + super(LinearEstimator, self).__init__( model_fn=_linear_model_fn, model_dir=model_dir, config=config, diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index 3a559377d66..fc643774528 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -1665,15 +1665,15 @@ class LinearEstimatorTest(test.TestCase): 'feature', dimension=4) ] exp = experiment.Experiment( - estimator=linear._LinearEstimator(feature_columns=cont_features, - head=head_lib._regression_head()), + estimator=linear.LinearEstimator(feature_columns=cont_features, + head=head_lib.regression_head()), train_input_fn=test_data.iris_input_logistic_fn, eval_input_fn=test_data.iris_input_logistic_fn) exp.test() def testEstimatorContract(self): estimator_test_utils.assert_estimator_contract(self, - linear._LinearEstimator) + linear.LinearEstimator) def testLinearRegression(self): """Tests that loss goes down with training.""" @@ -1691,8 +1691,8 @@ class LinearEstimatorTest(test.TestCase): 100) age = feature_column_lib.real_valued_column('age') - linear_estimator = linear._LinearEstimator(feature_columns=[age, language], - head=head_lib._regression_head()) + linear_estimator = linear.LinearEstimator(feature_columns=[age, language], + head=head_lib.regression_head()) linear_estimator.fit(input_fn=input_fn, steps=100) loss1 = linear_estimator.evaluate(input_fn=input_fn, steps=1)['loss'] linear_estimator.fit(input_fn=input_fn, steps=400) @@ -1717,9 +1717,9 @@ class LinearEstimatorTest(test.TestCase): 100) age = feature_column_lib.real_valued_column('age') - linear_estimator = linear._LinearEstimator( + linear_estimator = linear.LinearEstimator( feature_columns=[age, language], - head=head_lib._poisson_regression_head()) + head=head_lib.poisson_regression_head()) linear_estimator.fit(input_fn=input_fn, steps=10) loss1 = linear_estimator.evaluate(input_fn=input_fn, steps=1)['loss'] linear_estimator.fit(input_fn=input_fn, steps=100) @@ -1736,8 +1736,8 @@ class LinearEstimatorTest(test.TestCase): sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( example_id_column='example_id') with self.assertRaises(ValueError): - linear._LinearEstimator( - head=head_lib._regression_head(label_dimension=1), + linear.LinearEstimator( + head=head_lib.regression_head(label_dimension=1), feature_columns=[maintenance_cost, sq_footage], optimizer=sdca_optimizer, _joint_weights=True) diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py index dfea0e030fa..5a991da8917 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/svm.py +++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py @@ -139,7 +139,7 @@ class SVM(estimator.Estimator): model_dir=model_dir, config=config, params={ - "head": head_lib._binary_svm_head( # pylint: disable=protected-access + "head": head_lib.binary_svm_head( 
weight_column_name=weight_column_name, enable_centered_bias=False), "feature_columns": feature_columns,
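With the head factories and the two generic estimators now exported, a model can be assembled entirely from public symbols. The following is a minimal usage sketch, assuming this patch is applied; it is not part of the change itself, and the feature column, weight values, and `input_fn` below are hypothetical illustrations. Only `multi_class_head`, `poisson_regression_head`, `DNNEstimator`, and `LinearEstimator` come from the patch.

```python
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib import learn

# A weighted binary head, mirroring the DNNEstimator docstring example.
head = learn.multi_class_head(n_classes=2, weight_column_name="w")

# DNNEstimator is now public: any Head can be plugged into a DNN model.
dnn_estimator = learn.DNNEstimator(
    head=head,
    feature_columns=[layers.real_valued_column("x")],
    hidden_units=[8, 4])

def input_fn():
  # Hypothetical toy data; "w" carries the per-example weights read by the
  # head via weight_column_name.
  features = {
      "x": tf.constant([[1.], [2.], [3.]]),
      "w": tf.constant([[1.], [1.], [.5]]),
  }
  labels = tf.constant([[0], [1], [1]])
  return features, labels

dnn_estimator.fit(input_fn=input_fn, steps=5)

# LinearEstimator pairs a linear model with any head, e.g. Poisson
# regression as in the linear.py docstring example.
linear_estimator = learn.LinearEstimator(
    head=learn.poisson_regression_head(),
    feature_columns=[layers.real_valued_column("x")])
```

The `_`-prefixed aliases at the bottom of head.py keep any remaining internal call sites working until they migrate to the public names.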