Make the head/multi-head API public; update selected users and leave aliases for the rest.
This is CL#1 of a series of CLs to make the head/multi-head API public and migrate all users. This CL:
+ Makes the Head interface and factory functions public.
+ Updates all tf-learn internal and SIR usage.
+ Leaves aliases for the legacy private names; the aliases, and all remaining usages, will be removed in the next CLs.
+ Updates documentation.
Change: 149613397
parent 5b3e560d2f
commit 58067591b6
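For reference, the rename this series performs is a pure de-privatization: every factory keeps its signature and only loses the leading underscore. A minimal before/after sketch (names taken from the diff below):

```python
# Before this CL (private names; they keep working via aliases until the
# migration finishes):
head = head_lib._multi_class_head(n_classes=2)

# After this CL (public names, also exported under tf.contrib.learn):
head = head_lib.multi_class_head(n_classes=2)
head = tf.contrib.learn.multi_class_head(n_classes=2)
```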
@@ -28,13 +28,24 @@ See the @{$python/contrib.learn} guide.
 @@MetricSpec
 @@PredictionKey
 @@DNNClassifier
+@@DNNEstimator
 @@DNNRegressor
 @@DNNLinearCombinedRegressor
 @@DNNLinearCombinedClassifier
 @@LinearClassifier
+@@LinearEstimator
 @@LinearRegressor
 @@LogisticRegressor

+@@Head
+@@multi_class_head
+@@multi_label_head
+@@binary_svm_head
+@@regression_head
+@@poisson_regression_head
+@@multi_head
+@@no_op_train_fn
+
 @@Experiment
 @@ExportStrategy
 @@TaskType
@@ -295,6 +295,7 @@ from __future__ import print_function
 from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
 from tensorflow.contrib.learn.python.learn.estimators.constants import ProblemType
 from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
+from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNEstimator
 from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
 from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedClassifier
 from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import DNNLinearCombinedRegressor
@@ -304,8 +305,17 @@ from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
 from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
 from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
 from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
+from tensorflow.contrib.learn.python.learn.estimators.head import binary_svm_head
+from tensorflow.contrib.learn.python.learn.estimators.head import Head
+from tensorflow.contrib.learn.python.learn.estimators.head import multi_class_head
+from tensorflow.contrib.learn.python.learn.estimators.head import multi_head
+from tensorflow.contrib.learn.python.learn.estimators.head import multi_label_head
+from tensorflow.contrib.learn.python.learn.estimators.head import no_op_train_fn
+from tensorflow.contrib.learn.python.learn.estimators.head import poisson_regression_head
+from tensorflow.contrib.learn.python.learn.estimators.head import regression_head
 from tensorflow.contrib.learn.python.learn.estimators.kmeans import KMeansClustering
 from tensorflow.contrib.learn.python.learn.estimators.linear import LinearClassifier
+from tensorflow.contrib.learn.python.learn.estimators.linear import LinearEstimator
 from tensorflow.contrib.learn.python.learn.estimators.linear import LinearRegressor
 from tensorflow.contrib.learn.python.learn.estimators.logistic_regressor import LogisticRegressor
 from tensorflow.contrib.learn.python.learn.estimators.metric_key import MetricKey
@@ -131,7 +131,7 @@ class ComposableModelTest(test.TestCase):
     language = feature_column.sparse_column_with_hash_bucket('language', 100)
     age = feature_column.real_valued_column('age')

-    head = head_lib._multi_class_head(n_classes=2)
+    head = head_lib.multi_class_head(n_classes=2)
     classifier = _linear_estimator(head, feature_columns=[age, language])

     classifier.fit(input_fn=input_fn, steps=1000)
@@ -157,7 +157,7 @@ class ComposableModelTest(test.TestCase):
     language = feature_column.sparse_column_with_hash_bucket('language', 100)
     age = feature_column.sparse_column_with_hash_bucket('age', 2)

-    head = head_lib._multi_class_head(n_classes=2)
+    head = head_lib.multi_class_head(n_classes=2)
     classifier = _joint_linear_estimator(head, feature_columns=[age, language])

     classifier.fit(input_fn=input_fn, steps=1000)
@@ -171,7 +171,7 @@ class ComposableModelTest(test.TestCase):
     """Tests multi-class classification using matrix data as input."""
     cont_features = [feature_column.real_valued_column('feature', dimension=4)]

-    head = head_lib._multi_class_head(n_classes=3)
+    head = head_lib.multi_class_head(n_classes=3)
     classifier = _dnn_estimator(
         head, feature_columns=cont_features, hidden_units=[3, 3])

@@ -304,7 +304,7 @@ class DNNClassifier(estimator.Estimator):
         config=config,
         params={
             "head":
-                head_lib._multi_class_head(  # pylint: disable=protected-access
+                head_lib.multi_class_head(
                     n_classes,
                     weight_column_name=weight_column_name,
                     enable_centered_bias=enable_centered_bias),
@@ -579,7 +579,7 @@ class DNNRegressor(estimator.Estimator):
         config=config,
         params={
             "head":
-                head_lib._regression_head(  # pylint: disable=protected-access
+                head_lib.regression_head(
                     label_dimension=label_dimension,
                     weight_column_name=weight_column_name,
                     enable_centered_bias=enable_centered_bias),
@@ -731,8 +731,7 @@ class DNNRegressor(estimator.Estimator):
         exports_to_keep=exports_to_keep)


-# TODO(zakaria): Make it public when b/34751732 is fixed.
-class _DNNEstimator(estimator.Estimator):
+class DNNEstimator(estimator.Estimator):
   """A Estimator for TensorFlow DNN models with user specified _Head.

   Example:
@@ -745,20 +744,20 @@ class _DNNEstimator(estimator.Estimator):
                                                      ...)
   sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b,
                                                      ...)
-  To create a _DNNEstimator for binary classification, where
-  estimator = _DNNEstimator(
+  To create a DNNEstimator for binary classification, where
+  estimator = DNNEstimator(
       feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
-      head=head_lib._multi_class__head(n_classes=2),
+      head=tf.contrib.learn.multi_class_head(n_classes=2),
       hidden_units=[1024, 512, 256])

   If your label is keyed with "y" in your labels dict, and weights are keyed
   with "w" in features dict, and you want to enable centered bias,
-  head = head_lib._multi_class__head(
+  head = tf.contrib.learn.multi_class_head(
       n_classes=2,
       label_name="x",
       weight_column_name="w",
       enable_centered_bias=True)
-  estimator = _DNNEstimator(
+  estimator = DNNEstimator(
       feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
       head=head,
       hidden_units=[1024, 512, 256])
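The docstring above covers the binary-classification case; here is a self-contained sketch of constructing the now-public class. The feature names are illustrative, not taken from this CL:

```python
from tensorflow.contrib import layers, learn

# Illustrative feature columns; 'query_tokens' is a made-up feature name.
sparse_tokens = layers.sparse_column_with_hash_bucket('query_tokens', 1000)
token_emb = layers.embedding_column(sparse_tokens, dimension=16)

estimator = learn.DNNEstimator(
    feature_columns=[token_emb],
    head=learn.multi_class_head(n_classes=2),
    hidden_units=[1024, 512, 256])
```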
@@ -802,10 +801,10 @@ class _DNNEstimator(estimator.Estimator):
                feature_engineering_fn=None,
                embedding_lr_multipliers=None,
                input_layer_min_slice_size=None):
-    """Initializes a _DNNEstimator instance.
+    """Initializes a `DNNEstimator` instance.

     Args:
-      head: _Head instance.
+      head: `Head` instance.
       hidden_units: List of hidden units per layer. All layers are fully
         connected. Ex. `[64, 32]` means first layer has 64 nodes and second one
         has 32.
@@ -836,9 +835,9 @@ class _DNNEstimator(estimator.Estimator):
         partitions. If not provided, will use the default of 64M.

     Returns:
-      A `_DNNEstimator` estimator.
+      A `DNNEstimator` estimator.
     """
-    super(_DNNEstimator, self).__init__(
+    super(DNNEstimator, self).__init__(
         model_fn=_dnn_model_fn,
         model_dir=model_dir,
         config=config,
@@ -854,4 +853,3 @@ class _DNNEstimator(estimator.Estimator):
         "input_layer_min_slice_size": input_layer_min_slice_size,
         },
         feature_engineering_fn=feature_engineering_fn)
-
@@ -550,7 +550,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
     if not self._feature_columns:
       raise ValueError("Either linear_feature_columns or dnn_feature_columns "
                        "must be defined.")
-    head = head_lib._multi_class_head(  # pylint: disable=protected-access
+    head = head_lib.multi_class_head(
         n_classes=n_classes,
         weight_column_name=weight_column_name,
         enable_centered_bias=enable_centered_bias)
@@ -841,7 +841,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
     if not self._feature_columns:
       raise ValueError("Either linear_feature_columns or dnn_feature_columns "
                        "must be defined.")
-    head = head_lib._regression_head(  # pylint: disable=protected-access
+    head = head_lib.regression_head(
         weight_column_name=weight_column_name,
         label_dimension=label_dimension,
         enable_centered_bias=enable_centered_bias)
@@ -60,7 +60,7 @@ def _assert_metrics_in_range(keys, metrics):
       metrics)


-class _CheckCallsHead(head_lib._Head):  # pylint: disable=protected-access
+class _CheckCallsHead(head_lib.Head):
   """Head that checks whether head_ops is called."""

   def __init__(self):
@@ -97,7 +97,7 @@ class EmbeddingMultiplierTest(test.TestCase):

     params = {
         'dnn_feature_columns': [one_hot_language],
-        'head': head_lib._multi_class_head(2),
+        'head': head_lib.multi_class_head(2),
         'dnn_hidden_units': [1],
         # Set lr mult to 0. to keep embeddings constant.
         'embedding_lr_multipliers': {
@@ -131,7 +131,7 @@ class EmbeddingMultiplierTest(test.TestCase):

     params = {
         'dnn_feature_columns': [embedding_language, embedding_wire],
-        'head': head_lib._multi_class_head(2),
+        'head': head_lib.multi_class_head(2),
         'dnn_hidden_units': [1],
         # Set lr mult to 0. to keep embeddings constant.
         'embedding_lr_multipliers': {
@@ -59,7 +59,7 @@ class EmbeddingMultiplierTest(test.TestCase):

     params = {
         'feature_columns': [one_hot_language],
-        'head': head_lib._multi_class_head(2),
+        'head': head_lib.multi_class_head(2),
         'hidden_units': [1],
         # Set lr mult to 0. to keep embeddings constant.
         'embedding_lr_multipliers': {
@@ -90,7 +90,7 @@ class EmbeddingMultiplierTest(test.TestCase):

     params = {
         'feature_columns': [embedding_language, embedding_wire],
-        'head': head_lib._multi_class_head(2),
+        'head': head_lib.multi_class_head(2),
         'hidden_units': [1],
         # Set lr mult to 0. to keep embeddings constant.
         'embedding_lr_multipliers': {
@@ -145,7 +145,7 @@ class DNNEstimatorTest(test.TestCase):
     exp.test()

   def testEstimatorContract(self):
-    estimator_test_utils.assert_estimator_contract(self, dnn._DNNEstimator)
+    estimator_test_utils.assert_estimator_contract(self, dnn.DNNEstimator)

   def testTrainWithWeights(self):
     """Tests training with given weight column."""
@@ -172,8 +172,8 @@ class DNNEstimatorTest(test.TestCase):
       }
       return features, labels

-    dnn_estimator = dnn._DNNEstimator(
-        head=head_lib._multi_class_head(2, weight_column_name='w'),
+    dnn_estimator = dnn.DNNEstimator(
+        head=head_lib.multi_class_head(2, weight_column_name='w'),
         feature_columns=[feature_column.real_valued_column('x')],
         hidden_units=[3, 3],
         config=run_config.RunConfig(tf_random_seed=1))
@@ -46,15 +46,141 @@ from tensorflow.python.ops import variables
 from tensorflow.python.summary import summary
 from tensorflow.python.training import training

-# TODO(zakaria): add functions that creates a head and returns ModelOpFn
+
+class Head(object):
+  """Interface for the head/top of a model.
+
+  Given logits (or the output of a hidden layer), a Head knows how to compute
+  predictions, loss, default metrics, and an export signature. It is meant to:
+
+  1) Simplify writing model_fn and make model_fn more configurable.
+  2) Support a wide range of machine learning models. Since most heads can
+     work with logits, they can support DNN, RNN, Wide, Wide&Deep, global
+     objectives, gradient boosted trees, and many other types of machine
+     learning models.
+  3) Allow users to seamlessly switch between 1 to n heads for multi-objective
+     learning (see the _MultiHead implementation for more details).
+
+  Common usage:
+  Here is a simplified model_fn to build a multiclass DNN model.
+  ```python
+  def _my_dnn_model_fn(features, labels, mode, params, config=None):
+    # Optionally your callers can pass head to model_fn as a param.
+    head = tf.contrib.learn.multi_class_head(...)
+    input = tf.contrib.layers.input_from_feature_columns(features, ...)
+    last_hidden_layer_out = tf.contrib.layers.stack(
+        input, tf.contrib.layers.fully_connected, [1000, 500])
+    logits = tf.contrib.layers.fully_connected(
+        last_hidden_layer_out, head.logits_dimension, activation_fn=None)
+
+    def _train_op_fn(loss):
+      return optimizer.minimize(loss)
+
+    return head.create_model_fn_ops(
+        features=features,
+        labels=labels,
+        mode=mode,
+        train_op_fn=_train_op_fn,
+        logits=logits,
+        scope=...)
+  ```
+
+  Most heads also support logits_input, which is typically the output of the
+  last hidden layer. Some heads (like heads responsible for candidate sampling
+  or hierarchical softmax) intrinsically will not support logits and you have
+  to pass logits_input. Here is a common usage:
+  ```python
+  return head.create_model_fn_ops(
+      features=features,
+      labels=labels,
+      mode=mode,
+      train_op_fn=_train_op_fn,
+      logits_input=last_hidden_layer_out,
+      scope=...)
+  ```
+
+  There are cases where computing and applying gradients cannot be
+  meaningfully captured with the train_op_fn we support (for example, with a
+  sync optimizer). In such cases, you can take the responsibility on your
+  own. Here is a common use case:
+  ```python
+  model_fn_ops = head.create_model_fn_ops(
+      features=features,
+      labels=labels,
+      mode=mode,
+      train_op_fn=tf.contrib.learn.no_op_train_fn,
+      logits=logits,
+      scope=...)
+  if mode == tf.contrib.learn.ModeKeys.TRAIN:
+    optimizer = ...
+    sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)
+    update_op = tf.contrib.layers.optimize_loss(optimizer=sync,
+                                                loss=model_fn_ops.loss, ...)
+    hooks = [sync.make_session_run_hook(is_chief)]
+    ... update train_op and hooks in ModelFnOps and return
+  ```
+  """
+  __metaclass__ = abc.ABCMeta
+
+  @abc.abstractproperty
+  def logits_dimension(self):
+    """Size of the last dimension of the logits `Tensor`.
+
+    Typically, logits is of shape `[batch_size, logits_dimension]`.
+
+    Returns:
+      The expected size of the `logits` tensor.
+    """
+    raise NotImplementedError("Calling an abstract method.")
+
+  @abc.abstractmethod
+  def create_model_fn_ops(self,
+                          features,
+                          mode,
+                          labels=None,
+                          train_op_fn=None,
+                          logits=None,
+                          logits_input=None,
+                          scope=None):
+    """Returns `ModelFnOps` that a model_fn can return.
+
+    Please note that:
+    + Exactly one of `logits` and `logits_input` must be provided.
+    + All args must be passed via name.
+
+    Args:
+      features: Input `dict` of `Tensor` objects.
+      mode: Estimator's `ModeKeys`.
+      labels: Labels `Tensor`, or `dict` of same.
+      train_op_fn: Function that takes a scalar loss `Tensor` and returns an
+        op to optimize the model with the loss. This is used in TRAIN mode
+        and must not be None. None is allowed in other modes. If you want to
+        optimize loss yourself, you can pass `no_op_train_fn` and then use
+        ModelFnOps.loss to compute and apply gradients.
+      logits: logits `Tensor` to be used by the head.
+      logits_input: `Tensor` from which to build logits, often needed when
+        you don't want to compute the logits. Typically this is the
+        activation of the last hidden layer in a DNN. Some heads (like the
+        ones responsible for candidate sampling) intrinsically avoid
+        computing full logits and only accept logits_input.
+      scope: Optional scope for `variable_scope`.
+
+    Returns:
+      An instance of `ModelFnOps`.
+
+    Raises:
+      ValueError: If `mode` is not recognized.
+      ValueError: If neither or both of `logits` and `logits_input` is
+        provided.
+    """
+    raise NotImplementedError("Calling an abstract method.")
+
+
-def _regression_head(label_name=None,
+def regression_head(label_name=None,
                     weight_column_name=None,
                     label_dimension=1,
                     enable_centered_bias=False,
                     head_name=None):
-  """Creates a _Head for linear regression.
+  """Creates a `Head` for linear regression.

   Args:
     label_name: String, name of the key in label dict. Can be null if label
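With the interface public (the removed TODO later in this diff reads "Make the classes public once we are ready for users to subclass them"), user-defined heads become possible. A minimal sketch of what a subclass could look like; the class name and the squared-error logic are illustrative, not part of this CL:

```python
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn


class _SquaredErrorHead(head_lib.Head):  # hypothetical subclass
  """Toy single-output regression head with mean squared error loss."""

  @property
  def logits_dimension(self):
    return 1

  def create_model_fn_ops(self, features, mode, labels=None,
                          train_op_fn=None, logits=None, logits_input=None,
                          scope=None):
    # For brevity this sketch only handles the `logits` path.
    predictions = {"scores": logits}
    loss = None
    train_op = None
    if mode != model_fn.ModeKeys.INFER:
      loss = tf.reduce_mean(tf.square(logits - labels))
      if mode == model_fn.ModeKeys.TRAIN:
        train_op = train_op_fn(loss)
    return model_fn.ModelFnOps(mode=mode, predictions=predictions,
                               loss=loss, train_op=train_op)
```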
@@ -73,7 +199,7 @@ def _regression_head(label_name=None,
     will be `head_name`.

   Returns:
-    An instance of _Head
+    An instance of `Head` for linear regression.
   """
   return _RegressionHead(
       label_name=label_name,
@@ -85,12 +211,12 @@ def _regression_head(label_name=None,
       link_fn=array_ops.identity)


-def _poisson_regression_head(label_name=None,
+def poisson_regression_head(label_name=None,
                             weight_column_name=None,
                             label_dimension=1,
                             enable_centered_bias=False,
                             head_name=None):
-  """Creates a _Head for linear regression.
+  """Creates a `Head` for poisson regression.

   Args:
     label_name: String, name of the key in label dict. Can be null if label
@@ -109,7 +235,7 @@ def _poisson_regression_head(label_name=None,
     will be `head_name`.

   Returns:
-    An instance of _Head
+    An instance of `Head` for poisson regression.
   """
   return _RegressionHead(
       label_name=label_name,
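Both factories above build the same private _RegressionHead; the Poisson variant (see the next hunk) plugs in _poisson_loss and math_ops.exp as the link function, so predictions are exp(logits) and therefore always positive. A minimal usage sketch; the label key is illustrative:

```python
# 'num_clicks' is an illustrative label key, not taken from this CL.
head = tf.contrib.learn.poisson_regression_head(label_name='num_clicks')
assert head.logits_dimension == 1  # label_dimension defaults to 1
```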
@@ -120,18 +246,18 @@ def _poisson_regression_head(label_name=None,
       loss_fn=_poisson_loss,
       link_fn=math_ops.exp)

-# TODO(zakaria): Add logistic_regression_head
+# TODO(zakaria): Consider adding a _RegressionHead for logistic_regression


-def _multi_class_head(n_classes,
+def multi_class_head(n_classes,
                      label_name=None,
                      weight_column_name=None,
                      enable_centered_bias=False,
                      head_name=None,
                      thresholds=None,
                      metric_class_ids=None,
                      loss_fn=None):
-  """Creates a _Head for multi class single label classification.
+  """Creates a `Head` for multi class single label classification.

   The Head uses softmax cross entropy loss.

@@ -157,7 +283,7 @@ def _multi_class_head(n_classes,
     optional. See `tf.losses`

   Returns:
-    An instance of _MultiClassHead.
+    An instance of `Head` for multi class classification.

   Raises:
     ValueError: If `n_classes` is < 2, or `metric_class_ids` is provided when
|
|||||||
loss_fn=loss_fn)
|
loss_fn=loss_fn)
|
||||||
|
|
||||||
|
|
||||||
def _binary_svm_head(
|
def binary_svm_head(
|
||||||
label_name=None,
|
label_name=None,
|
||||||
weight_column_name=None,
|
weight_column_name=None,
|
||||||
enable_centered_bias=False,
|
enable_centered_bias=False,
|
||||||
head_name=None,
|
head_name=None,
|
||||||
thresholds=None,):
|
thresholds=None,):
|
||||||
"""Creates a `_Head` for binary classification with SVMs.
|
"""Creates a `Head` for binary classification with SVMs.
|
||||||
|
|
||||||
The head uses binary hinge loss.
|
The head uses binary hinge loss.
|
||||||
|
|
||||||
@ -218,8 +344,7 @@ def _binary_svm_head(
|
|||||||
thresholds: thresholds for eval metrics, defaults to [.5]
|
thresholds: thresholds for eval metrics, defaults to [.5]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance of `_Head`.
|
An instance of `Head` for binary classification with SVM.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return _BinarySvmHead(
|
return _BinarySvmHead(
|
||||||
label_name=label_name,
|
label_name=label_name,
|
||||||
@ -229,15 +354,15 @@ def _binary_svm_head(
|
|||||||
thresholds=thresholds)
|
thresholds=thresholds)
|
||||||
|
|
||||||
|
|
||||||
def _multi_label_head(n_classes,
|
def multi_label_head(n_classes,
|
||||||
label_name=None,
|
label_name=None,
|
||||||
weight_column_name=None,
|
weight_column_name=None,
|
||||||
enable_centered_bias=False,
|
enable_centered_bias=False,
|
||||||
head_name=None,
|
head_name=None,
|
||||||
thresholds=None,
|
thresholds=None,
|
||||||
metric_class_ids=None,
|
metric_class_ids=None,
|
||||||
loss_fn=None):
|
loss_fn=None):
|
||||||
"""Creates a _Head for multi label classification.
|
"""Creates a Head for multi label classification.
|
||||||
|
|
||||||
The Head uses sigmoid cross entropy loss.
|
The Head uses sigmoid cross entropy loss.
|
||||||
|
|
||||||
@ -262,7 +387,7 @@ def _multi_label_head(n_classes,
|
|||||||
optional. See `tf.losses`
|
optional. See `tf.losses`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance of _MultiLabelHead.
|
An instance of `Head` for multi label classification.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If n_classes is < 2
|
ValueError: If n_classes is < 2
|
||||||
@ -284,16 +409,16 @@ def _multi_label_head(n_classes,
|
|||||||
loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)
|
loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None)
|
||||||
|
|
||||||
|
|
||||||
def _multi_head(heads, loss_weights=None):
|
def multi_head(heads, loss_weights=None):
|
||||||
"""Creates a MultiHead stemming from same logits/hidden layer.
|
"""Creates a MultiHead stemming from same logits/hidden layer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
heads: list of _Head objects.
|
heads: list of Head objects.
|
||||||
loss_weights: optional list of weights to be used to combine losses from
|
loss_weights: optional list of weights to be used to merge losses from
|
||||||
each head. All losses are weighted equally if not provided.
|
each head. All losses are weighted equally if not provided.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A _Head instance that combines multiple heads.
|
A instance of `Head` that merges multiple heads.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if heads and loss_weights have different size.
|
ValueError: if heads and loss_weights have different size.
|
||||||
@ -302,7 +427,7 @@ def _multi_head(heads, loss_weights=None):
|
|||||||
if len(loss_weights) != len(heads):
|
if len(loss_weights) != len(heads):
|
||||||
raise ValueError("heads and loss_weights must have same size")
|
raise ValueError("heads and loss_weights must have same size")
|
||||||
|
|
||||||
def _weighted_loss_combiner(losses):
|
def _weighted_loss_merger(losses):
|
||||||
if loss_weights:
|
if loss_weights:
|
||||||
if len(losses) != len(loss_weights):
|
if len(losses) != len(loss_weights):
|
||||||
raise ValueError("losses and loss_weights must have same size")
|
raise ValueError("losses and loss_weights must have same size")
|
||||||
@ -313,7 +438,7 @@ def _multi_head(heads, loss_weights=None):
|
|||||||
else:
|
else:
|
||||||
return math_ops.add_n(losses)
|
return math_ops.add_n(losses)
|
||||||
|
|
||||||
return _MultiHead(heads, loss_combiner=_weighted_loss_combiner)
|
return _MultiHead(heads, loss_merger=_weighted_loss_merger)
|
||||||
|
|
||||||
|
|
||||||
def no_op_train_fn(loss):
|
def no_op_train_fn(loss):
|
||||||
@ -321,64 +446,7 @@ def no_op_train_fn(loss):
|
|||||||
return control_flow_ops.no_op()
|
return control_flow_ops.no_op()
|
||||||
|
|
||||||
|
|
||||||
# TODO(zakaria): Make the classes public once we are ready for users to subclass
|
class _SingleHead(Head):
|
||||||
# them. See b/34751732
|
|
||||||
class _Head(object):
|
|
||||||
"""Interface for the head/top of a model.
|
|
||||||
|
|
||||||
Given logits or output of a hidden layer, a Head knows how to compute
|
|
||||||
predictions, loss, default metric and export signature.
|
|
||||||
"""
|
|
||||||
__metaclass__ = abc.ABCMeta
|
|
||||||
|
|
||||||
@abc.abstractproperty
|
|
||||||
def logits_dimension(self):
|
|
||||||
"""Size of the last dimension of the logits `Tensor`.
|
|
||||||
|
|
||||||
Typically, logits is of shape `[batch_size, logits_dimension]`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of logits values per example.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Calling an abstract method.")
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def create_model_fn_ops(self,
|
|
||||||
features,
|
|
||||||
mode,
|
|
||||||
labels=None,
|
|
||||||
train_op_fn=None,
|
|
||||||
logits=None,
|
|
||||||
logits_input=None,
|
|
||||||
scope=None):
|
|
||||||
"""Returns ops for a model_fn.
|
|
||||||
|
|
||||||
Exactly one of `logits` and `logits_input` must be provided.
|
|
||||||
|
|
||||||
All args must be passed via name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
features: Input `dict` of `Tensor` objects.
|
|
||||||
mode: Estimator's `ModeKeys`.
|
|
||||||
labels: Labels `Tensor`, or `dict` of same.
|
|
||||||
train_op_fn: Function that takes a scalar loss and returns an op to
|
|
||||||
optimize with the loss. Must not be `None` in TRAIN mode. If you want
|
|
||||||
to optimize loss yourself you can pass `no_op_train_fn`.
|
|
||||||
logits: logits `Tensor`, or `dict` of same, to be used for the head.
|
|
||||||
logits_input: `Tensor` from which to build logits.
|
|
||||||
scope: Optional scope for `variable_scope`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`ModelFnOps`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `mode` is not recognized, or neither or both of `logits`
|
|
||||||
and `logits_input` is provided.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Calling an abstract method.")
|
|
||||||
|
|
||||||
|
|
||||||
class _SingleHead(_Head):
|
|
||||||
"""Interface for a single head/top of a model."""
|
"""Interface for a single head/top of a model."""
|
||||||
__metaclass__ = abc.ABCMeta
|
__metaclass__ = abc.ABCMeta
|
||||||
|
|
||||||
@ -565,7 +633,7 @@ def _create_model_fn_ops(features,
|
|||||||
|
|
||||||
|
|
||||||
class _RegressionHead(_SingleHead):
|
class _RegressionHead(_SingleHead):
|
||||||
"""_Head for regression with a generalized linear model."""
|
"""`Head` for regression with a generalized linear model."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
label_dimension,
|
label_dimension,
|
||||||
@ -575,7 +643,7 @@ class _RegressionHead(_SingleHead):
|
|||||||
weight_column_name=None,
|
weight_column_name=None,
|
||||||
enable_centered_bias=False,
|
enable_centered_bias=False,
|
||||||
head_name=None):
|
head_name=None):
|
||||||
"""Head for regression.
|
"""`Head` for regression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
label_dimension: Number of regression labels per example. This is the
|
label_dimension: Number of regression labels per example. This is the
|
||||||
@ -614,7 +682,7 @@ class _RegressionHead(_SingleHead):
|
|||||||
logits=None,
|
logits=None,
|
||||||
logits_input=None,
|
logits_input=None,
|
||||||
scope=None):
|
scope=None):
|
||||||
"""See `_Head`."""
|
"""See `Head`."""
|
||||||
return _create_model_fn_ops(
|
return _create_model_fn_ops(
|
||||||
features=features,
|
features=features,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
@ -682,7 +750,7 @@ def _one_class_to_two_class_logits(logits):
|
|||||||
|
|
||||||
|
|
||||||
class _BinaryLogisticHead(_SingleHead):
|
class _BinaryLogisticHead(_SingleHead):
|
||||||
"""_Head for binary logistic classifciation."""
|
"""`Head` for binary classification with logistic regression."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
label_name=None,
|
label_name=None,
|
||||||
@ -691,7 +759,7 @@ class _BinaryLogisticHead(_SingleHead):
|
|||||||
head_name=None,
|
head_name=None,
|
||||||
loss_fn=None,
|
loss_fn=None,
|
||||||
thresholds=None):
|
thresholds=None):
|
||||||
"""Base type for all single heads.
|
"""`Head` for binary classification with logistic regression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
label_name: String, name of the key in label dict. Can be `None` if label
|
label_name: String, name of the key in label dict. Can be `None` if label
|
||||||
@ -729,7 +797,7 @@ class _BinaryLogisticHead(_SingleHead):
|
|||||||
logits=None,
|
logits=None,
|
||||||
logits_input=None,
|
logits_input=None,
|
||||||
scope=None):
|
scope=None):
|
||||||
"""See `_Head`."""
|
"""See `Head`."""
|
||||||
return _create_model_fn_ops(
|
return _create_model_fn_ops(
|
||||||
features=features,
|
features=features,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
@ -844,7 +912,7 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None):
|
|||||||
|
|
||||||
|
|
||||||
class _MultiClassHead(_SingleHead):
|
class _MultiClassHead(_SingleHead):
|
||||||
"""_Head for classification."""
|
"""'Head' for multi class classification."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
n_classes,
|
n_classes,
|
||||||
@ -855,7 +923,7 @@ class _MultiClassHead(_SingleHead):
|
|||||||
loss_fn=None,
|
loss_fn=None,
|
||||||
thresholds=None,
|
thresholds=None,
|
||||||
metric_class_ids=None):
|
metric_class_ids=None):
|
||||||
"""_Head for classification.
|
"""'Head' for multi class classification.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
|
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
|
||||||
@ -905,7 +973,7 @@ class _MultiClassHead(_SingleHead):
|
|||||||
logits=None,
|
logits=None,
|
||||||
logits_input=None,
|
logits_input=None,
|
||||||
scope=None):
|
scope=None):
|
||||||
"""See `_Head`."""
|
"""See `Head`."""
|
||||||
return _create_model_fn_ops(
|
return _create_model_fn_ops(
|
||||||
features=features,
|
features=features,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
@ -1039,7 +1107,7 @@ def _assert_labels_rank(labels):
|
|||||||
|
|
||||||
|
|
||||||
class _BinarySvmHead(_SingleHead):
|
class _BinarySvmHead(_SingleHead):
|
||||||
"""_Head for binary classification using SVMs."""
|
"""`Head` for binary classification using SVM."""
|
||||||
|
|
||||||
def __init__(self, label_name, weight_column_name, enable_centered_bias,
|
def __init__(self, label_name, weight_column_name, enable_centered_bias,
|
||||||
head_name, thresholds):
|
head_name, thresholds):
|
||||||
@ -1069,7 +1137,7 @@ class _BinarySvmHead(_SingleHead):
|
|||||||
logits=None,
|
logits=None,
|
||||||
logits_input=None,
|
logits_input=None,
|
||||||
scope=None):
|
scope=None):
|
||||||
"""See `_Head`."""
|
"""See `Head`."""
|
||||||
return _create_model_fn_ops(
|
return _create_model_fn_ops(
|
||||||
features=features,
|
features=features,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
@ -1125,7 +1193,7 @@ class _BinarySvmHead(_SingleHead):
|
|||||||
|
|
||||||
|
|
||||||
class _MultiLabelHead(_SingleHead):
|
class _MultiLabelHead(_SingleHead):
|
||||||
"""_Head for multlabel classification."""
|
"""`Head` for multi-label classification."""
|
||||||
|
|
||||||
# TODO(zakaria): add signature and metric for multilabel.
|
# TODO(zakaria): add signature and metric for multilabel.
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -1162,7 +1230,7 @@ class _MultiLabelHead(_SingleHead):
|
|||||||
logits=None,
|
logits=None,
|
||||||
logits_input=None,
|
logits_input=None,
|
||||||
scope=None):
|
scope=None):
|
||||||
"""See `_Head`."""
|
"""See `Head`."""
|
||||||
return _create_model_fn_ops(
|
return _create_model_fn_ops(
|
||||||
features=features,
|
features=features,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
@ -1240,24 +1308,52 @@ class _MultiLabelHead(_SingleHead):
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
class _MultiHead(_Head):
|
class _MultiHead(Head):
|
||||||
"""_Head to combine multiple _Head objects.
|
"""`Head` implementation for multi objective learning.
|
||||||
|
|
||||||
|
This class is responsible for using and merging the output of multiple
|
||||||
|
`Head` objects.
|
||||||
|
|
||||||
All heads stem from the same logits/logit_input tensor.
|
All heads stem from the same logits/logit_input tensor.
|
||||||
|
|
||||||
For training, combines losses of each heads according a function provided by
|
Common usage:
|
||||||
user.
|
For simple use cases you can pass the activation of hidden layer like
|
||||||
For eval, adds a /head_name suffix to the keys in eval metrics.
|
this from your model_fn,
|
||||||
For inference, updates keys prediction dict to a 2-tuple,
|
```python
|
||||||
(head_name, prediction_key)
|
last_hidden_layer_activation = ... Build your model.
|
||||||
|
multi_head = ...
|
||||||
|
return multi_head.create_model_fn_ops(
|
||||||
|
..., logits_input=last_hidden_layer_activation, ...)
|
||||||
|
```
|
||||||
|
|
||||||
|
Or you can create a logits tensor of
|
||||||
|
[batch_size, multi_head.logits_dimension] shape. _MultiHead will split the
|
||||||
|
logits for you.
|
||||||
|
return multi_head.create_model_fn_ops(..., logits=logits, ...)
|
||||||
|
|
||||||
|
For more complex use cases like a multi-task/multi-tower model or when logits
|
||||||
|
for each head has to be created separately, you can pass a dict of logits
|
||||||
|
where the keys match the name of the single heads.
|
||||||
|
```python
|
||||||
|
logits = {"head1": logits1, "head2": logits2}
|
||||||
|
return multi_head.create_model_fn_ops(..., logits=logits, ...)
|
||||||
|
```
|
||||||
|
|
||||||
|
Here is what this class does,
|
||||||
|
+ For training, merges losses of each heads according a function provided by
|
||||||
|
user, calls user provided train_op_fn with this final loss.
|
||||||
|
+ For eval, merges metrics by adding head_name suffix to the keys in eval
|
||||||
|
metrics.
|
||||||
|
+ For inference, updates keys in prediction dict to a 2-tuple,
|
||||||
|
(head_name, prediction_key)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, heads, loss_combiner):
|
def __init__(self, heads, loss_merger):
|
||||||
"""_Head to combine multiple _Head objects.
|
"""_Head to merges multiple _Head objects.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
heads: list of _Head objects.
|
heads: list of _Head objects.
|
||||||
loss_combiner: function that takes a list of loss tensors for the heads
|
loss_merger: function that takes a list of loss tensors for the heads
|
||||||
and returns the final loss tensor for the multi head.
|
and returns the final loss tensor for the multi head.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
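A note on the single-logits-tensor path mentioned in the docstring: logits_dimension of the multi head is the sum over its heads (see the next hunk), and _split_logits slices the tensor in head order. A fragment sketch; multi_head, last_hidden, features, labels and mode are placeholders from the surrounding model_fn:

```python
# With head1.logits_dimension == 3 and head2.logits_dimension == 1, the multi
# head expects logits of shape [batch_size, 4] and hands logits[:, 0:3] to
# head1 and logits[:, 3:4] to head2.
logits = tf.contrib.layers.fully_connected(
    last_hidden, multi_head.logits_dimension, activation_fn=None)
model_fn_ops = multi_head.create_model_fn_ops(
    features=features, labels=labels, mode=mode,
    train_op_fn=tf.contrib.learn.no_op_train_fn, logits=logits)
```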
@@ -1274,7 +1370,7 @@ class _MultiHead(_Head):
       self._logits_dimension += head.logits_dimension

     self._heads = heads
-    self._loss_combiner = loss_combiner
+    self._loss_merger = loss_merger

   @property
   def logits_dimension(self):
@@ -1353,11 +1449,11 @@ class _MultiHead(_Head):
     if mode == model_fn.ModeKeys.TRAIN:
       if train_op_fn is None:
         raise ValueError("train_op_fn can not be None in TRAIN mode.")
-      return self._combine_train(all_model_fn_ops, train_op_fn)
+      return self._merge_train(all_model_fn_ops, train_op_fn)
     if mode == model_fn.ModeKeys.INFER:
-      return self._combine_infer(all_model_fn_ops)
+      return self._merge_infer(all_model_fn_ops)
     if mode == model_fn.ModeKeys.EVAL:
-      return self._combine_eval(all_model_fn_ops)
+      return self._merge_eval(all_model_fn_ops)
     raise ValueError("mode=%s unrecognized" % str(mode))

   def _split_logits(self, logits):
@@ -1379,8 +1475,8 @@ class _MultiHead(_Head):
       begin += current_logits_size
     return all_logits

-  def _combine_train(self, all_model_fn_ops, train_op_fn):
-    """Combines list of ModelFnOps for training.
+  def _merge_train(self, all_model_fn_ops, train_op_fn):
+    """Merges list of ModelFnOps for training.

     Args:
       all_model_fn_ops: list of ModelFnOps for the individual heads.
@@ -1388,14 +1484,14 @@ class _MultiHead(_Head):
         documentation for more details.

     Returns:
-      ModelFnOps that combines all the heads.
+      ModelFnOps that merges all heads for TRAIN.
     """
     losses = []
     additional_train_ops = []
     for m in all_model_fn_ops:
       losses.append(m.loss)
       additional_train_ops.append(m.train_op)
-    loss = self._loss_combiner(losses)
+    loss = self._loss_merger(losses)

     train_op = train_op_fn(loss)
     train_op = control_flow_ops.group(train_op, *additional_train_ops)
@@ -1404,14 +1500,14 @@ class _MultiHead(_Head):
         loss=loss,
         train_op=train_op)

-  def _combine_infer(self, all_model_fn_ops):
-    """Combines list of ModelFnOps for inference.
+  def _merge_infer(self, all_model_fn_ops):
+    """Merges list of ModelFnOps for inference.

     Args:
       all_model_fn_ops: list of ModelFnOps for the individual heads.

     Returns:
-      ModelFnOps that combines all the heads.
+      ModelFnOps that merges all the heads for INFER.
     """
     predictions = {}
     output_alternatives = {}
@@ -1426,14 +1522,14 @@ class _MultiHead(_Head):
         predictions=predictions,
         output_alternatives=output_alternatives)

-  def _combine_eval(self, all_model_fn_ops):
-    """Combines list of ModelFnOps for eval.
+  def _merge_eval(self, all_model_fn_ops):
+    """Merges list of ModelFnOps for eval.

     Args:
       all_model_fn_ops: list of ModelFnOps for the individual heads.

     Returns:
-      ModelFnOps that combines all the heads.
+      ModelFnOps that merges all the heads for EVAL.
     """
     predictions = {}
     metrics = {}
@@ -1446,7 +1542,7 @@ class _MultiHead(_Head):
       for k, v in m.eval_metric_ops.items():
         # metrics["%s/%s" % (k, head_name)] = v
         metrics[k] = v
-    loss = self._loss_combiner(losses)
+    loss = self._loss_merger(losses)

     return model_fn.ModelFnOps(
         mode=model_fn.ModeKeys.EVAL,
|
|||||||
predictions, labels=labels, thresholds=(threshold,),
|
predictions, labels=labels, thresholds=(threshold,),
|
||||||
weights=_float_weights_or_none(weights))
|
weights=_float_weights_or_none(weights))
|
||||||
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
|
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
|
||||||
|
|
||||||
|
|
||||||
|
# Aliases
|
||||||
|
# TODO(zakaria): Remove these aliases, See b/34751732
|
||||||
|
_regression_head = regression_head
|
||||||
|
_poisson_regression_head = poisson_regression_head
|
||||||
|
_multi_class_head = multi_class_head
|
||||||
|
_binary_svm_head = binary_svm_head
|
||||||
|
_multi_label_head = multi_label_head
|
||||||
|
_multi_head = multi_head
|
||||||
|
_Head = Head
|
||||||
|
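The aliases keep not-yet-migrated callers working until the follow-up CLs land (b/34751732); old and new names resolve to the same objects:

```python
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib

# The legacy private names are now just aliases for the public ones.
assert head_lib._multi_class_head is head_lib.multi_class_head
assert head_lib._Head is head_lib.Head
```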
@@ -112,7 +112,7 @@ class PoissonHeadTest(test.TestCase):
     return sum(lpl)/len(lpl)

   def testPoissonWithLogits(self):
-    head = head_lib._poisson_regression_head()
+    head = head_lib.poisson_regression_head()
     labels = ((0.,), (1.,), (1.,))
     logits = ((0.,), (-1.,), (3.,))
     with ops.Graph().as_default(), session.Session():
@@ -140,7 +140,7 @@ class RegressionHeadTest(test.TestCase):

   # TODO(zakaria): test multilabel regression.
   def testRegressionWithLogits(self):
-    head = head_lib._regression_head()
+    head = head_lib.regression_head()
     with ops.Graph().as_default(), session.Session():
       model_fn_ops = head.create_model_fn_ops(
           {},
@@ -154,7 +154,7 @@ class RegressionHeadTest(test.TestCase):
       _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)

   def testRegressionWithInvalidLogits(self):
-    head = head_lib._regression_head()
+    head = head_lib.regression_head()
     with ops.Graph().as_default(), session.Session():
       with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
         head.create_model_fn_ops(
@@ -165,7 +165,7 @@ class RegressionHeadTest(test.TestCase):
           logits=((1., 1.), (1., 1.), (3., 1.)))

   def testRegressionWithLogitsInput(self):
-    head = head_lib._regression_head()
+    head = head_lib.regression_head()
     with ops.Graph().as_default(), session.Session():
       model_fn_ops = head.create_model_fn_ops(
           {},
@@ -183,7 +183,7 @@ class RegressionHeadTest(test.TestCase):
       _assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops)

   def testRegressionWithLogitsAndLogitsInput(self):
-    head = head_lib._regression_head()
+    head = head_lib.regression_head()
     with ops.Graph().as_default(), session.Session():
       with self.assertRaisesRegexp(
           ValueError, "Both logits and logits_input supplied"):
@@ -196,7 +196,7 @@ class RegressionHeadTest(test.TestCase):
           logits=((1.,), (1.,), (3.,)))

   def testRegressionEvalMode(self):
-    head = head_lib._regression_head()
+    head = head_lib.regression_head()
     with ops.Graph().as_default(), session.Session():
       model_fn_ops = head.create_model_fn_ops(
           {},
@@ -212,7 +212,7 @@ class RegressionHeadTest(test.TestCase):

   def testRegressionWithLabelName(self):
     label_name = "my_label"
-    head = head_lib._regression_head(label_name=label_name)
+    head = head_lib.regression_head(label_name=label_name)
     with ops.Graph().as_default(), session.Session():
       model_fn_ops = head.create_model_fn_ops(
           {},
@@ -226,7 +226,7 @@ class RegressionHeadTest(test.TestCase):
       _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)

   def testRegressionWithWeights(self):
-    head = head_lib._regression_head(weight_column_name="label_weight")
+    head = head_lib.regression_head(weight_column_name="label_weight")
     with ops.Graph().as_default(), session.Session():
       weights = ((2.,), (5.,), (0.,))
       model_fn_ops = head.create_model_fn_ops(
@@ -242,7 +242,7 @@ class RegressionHeadTest(test.TestCase):
                       model_fn_ops)

   def testRegressionWithCenteredBias(self):
-    head = head_lib._regression_head(enable_centered_bias=True)
+    head = head_lib.regression_head(enable_centered_bias=True)
     with ops.Graph().as_default(), session.Session():
       model_fn_ops = head.create_model_fn_ops(
           {},
@@ -264,7 +264,7 @@ class RegressionHeadTest(test.TestCase):
       _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)

   def testRegressionErrorInSparseTensorLabels(self):
-    head = head_lib._regression_head()
+    head = head_lib.regression_head()
     with ops.Graph().as_default():
       labels = sparse_tensor.SparseTensorValue(
           indices=((0, 0), (1, 0), (2, 0)),
@@ -317,7 +317,7 @@ class MultiLabelHeadTest(test.TestCase):

   def testMultiLabelWithLogits(self):
     n_classes = 3
-    head = head_lib._multi_label_head(
+    head = head_lib.multi_label_head(
         n_classes=n_classes, metric_class_ids=range(n_classes))
     with ops.Graph().as_default(), session.Session():
       model_fn_ops = head.create_model_fn_ops(
@@ -334,7 +334,7 @@ class MultiLabelHeadTest(test.TestCase):
     n_classes = 2
     labels = ((0, 1),)
     logits = ((1., 0.),)
-    head = head_lib._multi_label_head(
+    head = head_lib.multi_label_head(
         n_classes=n_classes, metric_class_ids=range(n_classes))
     with ops.Graph().as_default(), session.Session():
       model_fn_ops = head.create_model_fn_ops(
@@ -361,7 +361,7 @@ class MultiLabelHeadTest(test.TestCase):
       }, model_fn_ops)

   def testMultiLabelWithInvalidLogits(self):
-    head = head_lib._multi_label_head(n_classes=len(self._labels[0]) + 1)
+    head = head_lib.multi_label_head(n_classes=len(self._labels[0]) + 1)
     with ops.Graph().as_default(), session.Session():
       with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
         head.create_model_fn_ops(
@@ -370,7 +370,7 @@ class MultiLabelHeadTest(test.TestCase):

   def testMultiLabelWithLogitsInput(self):
     n_classes = 3
-    head = head_lib._multi_label_head(
+    head = head_lib.multi_label_head(
         n_classes=n_classes, metric_class_ids=range(n_classes))
     with ops.Graph().as_default(), session.Session():
       model_fn_ops = head.create_model_fn_ops(
@@ -407,7 +407,7 @@ class MultiLabelHeadTest(test.TestCase):

   def testMultiLabelWithLogitsAndLogitsInput(self):
     n_classes = 3
-    head = head_lib._multi_label_head(
+    head = head_lib.multi_label_head(
         n_classes=n_classes, metric_class_ids=range(n_classes))
     with ops.Graph().as_default(), session.Session():
       with self.assertRaisesRegexp(
@@ -418,7 +418,7 @@ class MultiLabelHeadTest(test.TestCase):

   def testMultiLabelEvalMode(self):
     n_classes = 3
-    head = head_lib._multi_label_head(
+    head = head_lib.multi_label_head(
         n_classes=n_classes, metric_class_ids=range(n_classes))
     with ops.Graph().as_default(), session.Session():
       model_fn_ops = head.create_model_fn_ops(
@@ -434,7 +434,7 @@ class MultiLabelHeadTest(test.TestCase):

   def testMultiClassEvalModeWithLargeLogits(self):
     n_classes = 3
-    head = head_lib._multi_label_head(
+    head = head_lib.multi_label_head(
         n_classes=n_classes, metric_class_ids=range(n_classes))
     logits = ((2., 0., -1),)
     with ops.Graph().as_default(), session.Session():
@@ -474,7 +474,7 @@ class MultiLabelHeadTest(test.TestCase):
   def testMultiLabelWithLabelName(self):
     n_classes = 3
     label_name = "my_label"
-    head = head_lib._multi_label_head(
+    head = head_lib.multi_label_head(
         n_classes=n_classes,
         label_name=label_name,
         metric_class_ids=range(n_classes))
@@ -491,7 +491,7 @@ class MultiLabelHeadTest(test.TestCase):

   def testMultiLabelWithWeight(self):
     n_classes = 3
-    head = head_lib._multi_label_head(
+    head = head_lib.multi_label_head(
         n_classes=n_classes,
         weight_column_name="label_weight",
         metric_class_ids=range(n_classes))
@@ -510,7 +510,7 @@ class MultiLabelHeadTest(test.TestCase):

   def testMultiLabelWithCustomLoss(self):
|
def testMultiLabelWithCustomLoss(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_label_head(
|
head = head_lib.multi_label_head(
|
||||||
n_classes=n_classes,
|
n_classes=n_classes,
|
||||||
weight_column_name="label_weight",
|
weight_column_name="label_weight",
|
||||||
metric_class_ids=range(n_classes),
|
metric_class_ids=range(n_classes),
|
||||||
@ -530,7 +530,7 @@ class MultiLabelHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultiLabelWithCenteredBias(self):
|
def testMultiLabelWithCenteredBias(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_label_head(
|
head = head_lib.multi_label_head(
|
||||||
n_classes=n_classes,
|
n_classes=n_classes,
|
||||||
enable_centered_bias=True,
|
enable_centered_bias=True,
|
||||||
metric_class_ids=range(n_classes))
|
metric_class_ids=range(n_classes))
|
||||||
@ -559,7 +559,7 @@ class MultiLabelHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultiLabelSparseTensorLabels(self):
|
def testMultiLabelSparseTensorLabels(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_label_head(
|
head = head_lib.multi_label_head(
|
||||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
labels = sparse_tensor.SparseTensorValue(
|
labels = sparse_tensor.SparseTensorValue(
|
||||||
@ -580,7 +580,7 @@ class MultiLabelHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultiLabelSparseTensorLabelsTooFewClasses(self):
|
def testMultiLabelSparseTensorLabelsTooFewClasses(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_label_head(
|
head = head_lib.multi_label_head(
|
||||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||||
# Set _logits_dimension (n_classes) to a lower value; if it's set to 1
|
# Set _logits_dimension (n_classes) to a lower value; if it's set to 1
|
||||||
# upfront, the class throws an error during initialization.
|
# upfront, the class throws an error during initialization.
|
||||||
@ -629,7 +629,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testBinaryClassificationWithLogits(self):
|
def testBinaryClassificationWithLogits(self):
|
||||||
n_classes = 2
|
n_classes = 2
|
||||||
head = head_lib._multi_class_head(n_classes=n_classes)
|
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||||
@ -644,7 +644,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
||||||
|
|
||||||
def testBinaryClassificationWithInvalidLogits(self):
|
def testBinaryClassificationWithInvalidLogits(self):
|
||||||
head = head_lib._multi_class_head(n_classes=len(self._labels) + 1)
|
head = head_lib.multi_class_head(n_classes=len(self._labels) + 1)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
||||||
head.create_model_fn_ops(
|
head.create_model_fn_ops(
|
||||||
@ -653,7 +653,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testBinaryClassificationWithLogitsInput(self):
|
def testBinaryClassificationWithLogitsInput(self):
|
||||||
n_classes = 2
|
n_classes = 2
|
||||||
head = head_lib._multi_class_head(n_classes=n_classes)
|
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||||
@ -682,7 +682,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
}, model_fn_ops)
|
}, model_fn_ops)
|
||||||
|
|
||||||
def testBinaryClassificationWithLogitsAndLogitsInput(self):
|
def testBinaryClassificationWithLogitsAndLogitsInput(self):
|
||||||
head = head_lib._multi_class_head(n_classes=2)
|
head = head_lib.multi_class_head(n_classes=2)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, "Both logits and logits_input supplied"):
|
ValueError, "Both logits and logits_input supplied"):
|
||||||
@ -692,7 +692,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testBinaryClassificationEvalMode(self):
|
def testBinaryClassificationEvalMode(self):
|
||||||
n_classes = 2
|
n_classes = 2
|
||||||
head = head_lib._multi_class_head(n_classes=n_classes)
|
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||||
@ -709,7 +709,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testBinaryClassificationInferMode(self):
|
def testBinaryClassificationInferMode(self):
|
||||||
n_classes = 2
|
n_classes = 2
|
||||||
head = head_lib._multi_class_head(n_classes=n_classes)
|
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||||
@ -722,8 +722,8 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testBinaryClassificationInferMode_withWightColumn(self):
|
def testBinaryClassificationInferMode_withWightColumn(self):
|
||||||
n_classes = 2
|
n_classes = 2
|
||||||
head = head_lib._multi_class_head(n_classes=n_classes,
|
head = head_lib.multi_class_head(n_classes=n_classes,
|
||||||
weight_column_name="label_weight")
|
weight_column_name="label_weight")
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||||
@ -738,7 +738,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testErrorInSparseTensorLabels(self):
|
def testErrorInSparseTensorLabels(self):
|
||||||
n_classes = 2
|
n_classes = 2
|
||||||
head = head_lib._multi_class_head(n_classes=n_classes)
|
head = head_lib.multi_class_head(n_classes=n_classes)
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
labels = sparse_tensor.SparseTensorValue(
|
labels = sparse_tensor.SparseTensorValue(
|
||||||
indices=((0, 0), (1, 0), (2, 0)),
|
indices=((0, 0), (1, 0), (2, 0)),
|
||||||
@ -755,7 +755,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testBinaryClassificationWithLabelName(self):
|
def testBinaryClassificationWithLabelName(self):
|
||||||
label_name = "my_label"
|
label_name = "my_label"
|
||||||
head = head_lib._multi_class_head(n_classes=2, label_name=label_name)
|
head = head_lib.multi_class_head(n_classes=2, label_name=label_name)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||||
@ -774,7 +774,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testBinaryClassificationWithWeights(self):
|
def testBinaryClassificationWithWeights(self):
|
||||||
n_classes = 2
|
n_classes = 2
|
||||||
head = head_lib._multi_class_head(
|
head = head_lib.multi_class_head(
|
||||||
n_classes=n_classes, weight_column_name="label_weight")
|
n_classes=n_classes, weight_column_name="label_weight")
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
weights = ((1.,), (0.,))
|
weights = ((1.,), (0.,))
|
||||||
@ -808,7 +808,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
model_fn_ops)
|
model_fn_ops)
|
||||||
|
|
||||||
def testBinaryClassificationWithCustomLoss(self):
|
def testBinaryClassificationWithCustomLoss(self):
|
||||||
head = head_lib._multi_class_head(
|
head = head_lib.multi_class_head(
|
||||||
n_classes=2, weight_column_name="label_weight",
|
n_classes=2, weight_column_name="label_weight",
|
||||||
loss_fn=_sigmoid_cross_entropy)
|
loss_fn=_sigmoid_cross_entropy)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
@ -844,7 +844,7 @@ class BinaryClassificationHeadTest(test.TestCase):
|
|||||||
model_fn_ops)
|
model_fn_ops)
|
||||||
|
|
||||||
def testBinaryClassificationWithCenteredBias(self):
|
def testBinaryClassificationWithCenteredBias(self):
|
||||||
head = head_lib._multi_class_head(n_classes=2, enable_centered_bias=True)
|
head = head_lib.multi_class_head(n_classes=2, enable_centered_bias=True)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
|
||||||
@ -904,7 +904,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultiClassWithLogits(self):
|
def testMultiClassWithLogits(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_class_head(
|
head = head_lib.multi_class_head(
|
||||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
@ -920,7 +920,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
self._expected_eval_metrics(expected_loss), model_fn_ops)
|
||||||
|
|
||||||
def testMultiClassWithInvalidLogits(self):
|
def testMultiClassWithInvalidLogits(self):
|
||||||
head = head_lib._multi_class_head(n_classes=len(self._logits[0]) + 1)
|
head = head_lib.multi_class_head(n_classes=len(self._logits[0]) + 1)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
||||||
head.create_model_fn_ops(
|
head.create_model_fn_ops(
|
||||||
@ -928,7 +928,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
logits=self._logits)
|
logits=self._logits)
|
||||||
|
|
||||||
def testMultiClassWithNoneTrainOpFnInTrain(self):
|
def testMultiClassWithNoneTrainOpFnInTrain(self):
|
||||||
head = head_lib._multi_class_head(n_classes=3)
|
head = head_lib.multi_class_head(n_classes=3)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, "train_op_fn can not be None in TRAIN mode"):
|
ValueError, "train_op_fn can not be None in TRAIN mode"):
|
||||||
@ -939,7 +939,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultiClassWithLogitsInput(self):
|
def testMultiClassWithLogitsInput(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_class_head(
|
head = head_lib.multi_class_head(
|
||||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
@ -978,7 +978,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultiClassWithLogitsAndLogitsInput(self):
|
def testMultiClassWithLogitsAndLogitsInput(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_class_head(
|
head = head_lib.multi_class_head(
|
||||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
@ -989,7 +989,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultiClassEvalMode(self):
|
def testMultiClassEvalMode(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_class_head(
|
head = head_lib.multi_class_head(
|
||||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
# logloss: z:label, x:logit
|
# logloss: z:label, x:logit
|
||||||
@ -1007,7 +1007,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultiClassEvalModeWithLargeLogits(self):
|
def testMultiClassEvalModeWithLargeLogits(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_class_head(
|
head = head_lib.multi_class_head(
|
||||||
n_classes=n_classes, metric_class_ids=range(n_classes))
|
n_classes=n_classes, metric_class_ids=range(n_classes))
|
||||||
logits = ((2., 0., -1),)
|
logits = ((2., 0., -1),)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
@ -1046,7 +1046,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultiClassWithWeight(self):
|
def testMultiClassWithWeight(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_class_head(
|
head = head_lib.multi_class_head(
|
||||||
n_classes=n_classes,
|
n_classes=n_classes,
|
||||||
weight_column_name="label_weight",
|
weight_column_name="label_weight",
|
||||||
metric_class_ids=range(n_classes))
|
metric_class_ids=range(n_classes))
|
||||||
@ -1069,7 +1069,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultiClassWithCustomLoss(self):
|
def testMultiClassWithCustomLoss(self):
|
||||||
n_classes = 3
|
n_classes = 3
|
||||||
head = head_lib._multi_class_head(
|
head = head_lib.multi_class_head(
|
||||||
n_classes=n_classes,
|
n_classes=n_classes,
|
||||||
weight_column_name="label_weight",
|
weight_column_name="label_weight",
|
||||||
metric_class_ids=range(n_classes),
|
metric_class_ids=range(n_classes),
|
||||||
@ -1094,7 +1094,7 @@ class MultiClassHeadTest(test.TestCase):
|
|||||||
def testInvalidNClasses(self):
|
def testInvalidNClasses(self):
|
||||||
for n_classes in (None, -1, 0, 1):
|
for n_classes in (None, -1, 0, 1):
|
||||||
with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):
|
with self.assertRaisesRegexp(ValueError, "n_classes must be > 1"):
|
||||||
head_lib._multi_class_head(n_classes=n_classes)
|
head_lib.multi_class_head(n_classes=n_classes)
|
||||||
|
|
||||||
|
|
||||||
class BinarySvmHeadTest(test.TestCase):
|
class BinarySvmHeadTest(test.TestCase):
|
||||||
@ -1116,7 +1116,7 @@ class BinarySvmHeadTest(test.TestCase):
|
|||||||
self._expected_losses = (.5, 0.)
|
self._expected_losses = (.5, 0.)
|
||||||
|
|
||||||
def testBinarySVMWithLogits(self):
|
def testBinarySVMWithLogits(self):
|
||||||
head = head_lib._binary_svm_head()
|
head = head_lib.binary_svm_head()
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
model_fn_ops = head.create_model_fn_ops(
|
model_fn_ops = head.create_model_fn_ops(
|
||||||
{},
|
{},
|
||||||
@ -1134,7 +1134,7 @@ class BinarySvmHeadTest(test.TestCase):
|
|||||||
}, model_fn_ops)
|
}, model_fn_ops)
|
||||||
|
|
||||||
def testBinarySVMWithInvalidLogits(self):
|
def testBinarySVMWithInvalidLogits(self):
|
||||||
head = head_lib._binary_svm_head()
|
head = head_lib.binary_svm_head()
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
|
||||||
head.create_model_fn_ops(
|
head.create_model_fn_ops(
|
||||||
@ -1142,7 +1142,7 @@ class BinarySvmHeadTest(test.TestCase):
|
|||||||
logits=np.ones((2, 2)))
|
logits=np.ones((2, 2)))
|
||||||
|
|
||||||
def testBinarySVMWithLogitsInput(self):
|
def testBinarySVMWithLogitsInput(self):
|
||||||
head = head_lib._binary_svm_head()
|
head = head_lib.binary_svm_head()
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
model_fn_ops = head.create_model_fn_ops(
|
model_fn_ops = head.create_model_fn_ops(
|
||||||
{},
|
{},
|
||||||
@ -1164,7 +1164,7 @@ class BinarySvmHeadTest(test.TestCase):
|
|||||||
}, model_fn_ops)
|
}, model_fn_ops)
|
||||||
|
|
||||||
def testBinarySVMWithLogitsAndLogitsInput(self):
|
def testBinarySVMWithLogitsAndLogitsInput(self):
|
||||||
head = head_lib._binary_svm_head()
|
head = head_lib.binary_svm_head()
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, "Both logits and logits_input supplied"):
|
ValueError, "Both logits and logits_input supplied"):
|
||||||
@ -1177,7 +1177,7 @@ class BinarySvmHeadTest(test.TestCase):
|
|||||||
logits=self._predictions)
|
logits=self._predictions)
|
||||||
|
|
||||||
def testBinarySVMEvalMode(self):
|
def testBinarySVMEvalMode(self):
|
||||||
head = head_lib._binary_svm_head()
|
head = head_lib.binary_svm_head()
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
model_fn_ops = head.create_model_fn_ops(
|
model_fn_ops = head.create_model_fn_ops(
|
||||||
{},
|
{},
|
||||||
@ -1197,7 +1197,7 @@ class BinarySvmHeadTest(test.TestCase):
|
|||||||
|
|
||||||
def testBinarySVMWithLabelName(self):
|
def testBinarySVMWithLabelName(self):
|
||||||
label_name = "my_label"
|
label_name = "my_label"
|
||||||
head = head_lib._binary_svm_head(label_name=label_name)
|
head = head_lib.binary_svm_head(label_name=label_name)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
model_fn_ops = head.create_model_fn_ops(
|
model_fn_ops = head.create_model_fn_ops(
|
||||||
{},
|
{},
|
||||||
@ -1215,7 +1215,7 @@ class BinarySvmHeadTest(test.TestCase):
|
|||||||
}, model_fn_ops)
|
}, model_fn_ops)
|
||||||
|
|
||||||
def testBinarySVMWithWeights(self):
|
def testBinarySVMWithWeights(self):
|
||||||
head = head_lib._binary_svm_head(weight_column_name="weights")
|
head = head_lib.binary_svm_head(weight_column_name="weights")
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
weights = (7., 11.)
|
weights = (7., 11.)
|
||||||
model_fn_ops = head.create_model_fn_ops(
|
model_fn_ops = head.create_model_fn_ops(
|
||||||
@ -1235,7 +1235,7 @@ class BinarySvmHeadTest(test.TestCase):
|
|||||||
}, model_fn_ops)
|
}, model_fn_ops)
|
||||||
|
|
||||||
def testBinarySVMWithCenteredBias(self):
|
def testBinarySVMWithCenteredBias(self):
|
||||||
head = head_lib._binary_svm_head(enable_centered_bias=True)
|
head = head_lib.binary_svm_head(enable_centered_bias=True)
|
||||||
with ops.Graph().as_default(), session.Session():
|
with ops.Graph().as_default(), session.Session():
|
||||||
model_fn_ops = head.create_model_fn_ops(
|
model_fn_ops = head.create_model_fn_ops(
|
||||||
{},
|
{},
|
||||||
@ -1265,21 +1265,21 @@ class BinarySvmHeadTest(test.TestCase):
|
|||||||
class MultiHeadTest(test.TestCase):
|
class MultiHeadTest(test.TestCase):
|
||||||
|
|
||||||
def testInvalidHeads(self):
|
def testInvalidHeads(self):
|
||||||
named_head = head_lib._multi_class_head(
|
named_head = head_lib.multi_class_head(
|
||||||
n_classes=3, label_name="label", head_name="head1")
|
n_classes=3, label_name="label", head_name="head1")
|
||||||
unnamed_head = head_lib._multi_class_head(
|
unnamed_head = head_lib.multi_class_head(
|
||||||
n_classes=4, label_name="label")
|
n_classes=4, label_name="label")
|
||||||
with self.assertRaisesRegexp(ValueError, "must have names"):
|
with self.assertRaisesRegexp(ValueError, "must have names"):
|
||||||
head_lib._multi_head((named_head, unnamed_head))
|
head_lib.multi_head((named_head, unnamed_head))
|
||||||
with self.assertRaisesRegexp(ValueError, "must be SingleHead"):
|
with self.assertRaisesRegexp(ValueError, "must be SingleHead"):
|
||||||
head_lib._multi_head((named_head, head_lib._multi_head((named_head,))))
|
head_lib.multi_head((named_head, head_lib.multi_head((named_head,))))
|
||||||
|
|
||||||
def testTrainWithNoneTrainOpFn(self):
|
def testTrainWithNoneTrainOpFn(self):
|
||||||
head1 = head_lib._multi_class_head(
|
head1 = head_lib.multi_class_head(
|
||||||
n_classes=3, label_name="label1", head_name="head1")
|
n_classes=3, label_name="label1", head_name="head1")
|
||||||
head2 = head_lib._multi_class_head(
|
head2 = head_lib.multi_class_head(
|
||||||
n_classes=4, label_name="label2", head_name="head2")
|
n_classes=4, label_name="label2", head_name="head2")
|
||||||
head = head_lib._multi_head((head1, head2))
|
head = head_lib.multi_head((head1, head2))
|
||||||
labels = {
|
labels = {
|
||||||
"label1": (1,),
|
"label1": (1,),
|
||||||
"label2": (1,)
|
"label2": (1,)
|
||||||
@ -1294,11 +1294,11 @@ class MultiHeadTest(test.TestCase):
|
|||||||
logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))
|
logits=((-0.7, 0.2, .1, .1, .1, .1, .1),))
|
||||||
|
|
||||||
def testTrain_withNoHeadWeights(self):
|
def testTrain_withNoHeadWeights(self):
|
||||||
head1 = head_lib._multi_class_head(
|
head1 = head_lib.multi_class_head(
|
||||||
n_classes=3, label_name="label1", head_name="head1")
|
n_classes=3, label_name="label1", head_name="head1")
|
||||||
head2 = head_lib._multi_class_head(
|
head2 = head_lib.multi_class_head(
|
||||||
n_classes=4, label_name="label2", head_name="head2")
|
n_classes=4, label_name="label2", head_name="head2")
|
||||||
head = head_lib._multi_head((head1, head2))
|
head = head_lib.multi_head((head1, head2))
|
||||||
labels = {
|
labels = {
|
||||||
"label1": (1,),
|
"label1": (1,),
|
||||||
"label2": (1,)
|
"label2": (1,)
|
||||||
@ -1320,11 +1320,11 @@ class MultiHeadTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
|
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
|
||||||
|
|
||||||
def testTrain_withHeadWeights(self):
|
def testTrain_withHeadWeights(self):
|
||||||
head1 = head_lib._multi_class_head(
|
head1 = head_lib.multi_class_head(
|
||||||
n_classes=3, label_name="label1", head_name="head1")
|
n_classes=3, label_name="label1", head_name="head1")
|
||||||
head2 = head_lib._multi_class_head(
|
head2 = head_lib.multi_class_head(
|
||||||
n_classes=4, label_name="label2", head_name="head2")
|
n_classes=4, label_name="label2", head_name="head2")
|
||||||
head = head_lib._multi_head((head1, head2), (1, .5))
|
head = head_lib.multi_head((head1, head2), (1, .5))
|
||||||
labels = {
|
labels = {
|
||||||
"label1": (1,),
|
"label1": (1,),
|
||||||
"label2": (1,)
|
"label2": (1,)
|
||||||
@ -1345,11 +1345,11 @@ class MultiHeadTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3)
|
self.assertAlmostEqual(1.531, sess.run(model_fn_ops.loss), places=3)
|
||||||
|
|
||||||
def testTrain_withDictLogits(self):
|
def testTrain_withDictLogits(self):
|
||||||
head1 = head_lib._multi_class_head(
|
head1 = head_lib.multi_class_head(
|
||||||
n_classes=3, label_name="label1", head_name="head1")
|
n_classes=3, label_name="label1", head_name="head1")
|
||||||
head2 = head_lib._multi_class_head(
|
head2 = head_lib.multi_class_head(
|
||||||
n_classes=4, label_name="label2", head_name="head2")
|
n_classes=4, label_name="label2", head_name="head2")
|
||||||
head = head_lib._multi_head((head1, head2))
|
head = head_lib.multi_head((head1, head2))
|
||||||
labels = {
|
labels = {
|
||||||
"label1": (1,),
|
"label1": (1,),
|
||||||
"label2": (1,)
|
"label2": (1,)
|
||||||
@ -1372,11 +1372,11 @@ class MultiHeadTest(test.TestCase):
|
|||||||
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
|
self.assertAlmostEqual(2.224, sess.run(model_fn_ops.loss), places=3)
|
||||||
|
|
||||||
def testInfer(self):
|
def testInfer(self):
|
||||||
head1 = head_lib._multi_class_head(
|
head1 = head_lib.multi_class_head(
|
||||||
n_classes=3, label_name="label1", head_name="head1")
|
n_classes=3, label_name="label1", head_name="head1")
|
||||||
head2 = head_lib._multi_class_head(
|
head2 = head_lib.multi_class_head(
|
||||||
n_classes=4, label_name="label2", head_name="head2")
|
n_classes=4, label_name="label2", head_name="head2")
|
||||||
head = head_lib._multi_head((head1, head2), (1, .5))
|
head = head_lib.multi_head((head1, head2), (1, .5))
|
||||||
labels = {
|
labels = {
|
||||||
"label1": (1,),
|
"label1": (1,),
|
||||||
"label2": (1,)
|
"label2": (1,)
|
||||||
@ -1422,11 +1422,11 @@ class MultiHeadTest(test.TestCase):
|
|||||||
), model_fn_ops.output_alternatives["head2"][1].keys())
|
), model_fn_ops.output_alternatives["head2"][1].keys())
|
||||||
|
|
||||||
def testEval(self):
|
def testEval(self):
|
||||||
head1 = head_lib._multi_class_head(
|
head1 = head_lib.multi_class_head(
|
||||||
n_classes=3, label_name="label1", head_name="head1")
|
n_classes=3, label_name="label1", head_name="head1")
|
||||||
head2 = head_lib._multi_class_head(
|
head2 = head_lib.multi_class_head(
|
||||||
n_classes=4, label_name="label2", head_name="head2")
|
n_classes=4, label_name="label2", head_name="head2")
|
||||||
head = head_lib._multi_head((head1, head2), (1, .5))
|
head = head_lib.multi_head((head1, head2), (1, .5))
|
||||||
labels = {
|
labels = {
|
||||||
"label1": (1,),
|
"label1": (1,),
|
||||||
"label2": (1,)
|
"label2": (1,)
|
||||||
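
Every edit in the test file above is the same mechanical rename: the factory functions drop their leading underscore (`_regression_head` → `regression_head`, `_multi_label_head` → `multi_label_head`, `_multi_class_head` → `multi_class_head`, `_binary_svm_head` → `binary_svm_head`, `_multi_head` → `multi_head`) while every argument list stays as it was. As a sketch of what calling the now-public API looks like, assembled from the calls exercised in MultiHeadTest (the `head_lib` import alias follows the test file's convention; no new behavior is implied):

    from tensorflow.contrib.learn.python.learn.estimators import head as head_lib

    # Two named single heads; head_name keeps their losses and metrics
    # from colliding once they are combined.
    head1 = head_lib.multi_class_head(
        n_classes=3, label_name="label1", head_name="head1")
    head2 = head_lib.multi_class_head(
        n_classes=4, label_name="label2", head_name="head2")

    # A multi-head whose training loss is the weighted sum of the per-head
    # losses (weights 1 and .5 here, as in testTrain_withHeadWeights).
    head = head_lib.multi_head((head1, head2), (1, .5))
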
@@ -419,7 +419,7 @@ class LinearClassifier(estimator.Estimator):
       enable_centered_bias = False
       logging.warning("centered_bias is not supported with SDCA, "
                       "please disable it explicitly.")
-    head = head_lib._multi_class_head(  # pylint: disable=protected-access
+    head = head_lib.multi_class_head(
         n_classes,
         weight_column_name=weight_column_name,
         enable_centered_bias=enable_centered_bias)
@@ -686,7 +686,7 @@ class LinearRegressor(estimator.Estimator):
       enable_centered_bias = False
       logging.warning("centered_bias is not supported with SDCA, "
                       "please disable it explicitly.")
-    head = head_lib._regression_head(  # pylint: disable=protected-access
+    head = head_lib.regression_head(
         weight_column_name=weight_column_name,
         label_dimension=label_dimension,
         enable_centered_bias=enable_centered_bias)
@@ -824,8 +824,7 @@ class LinearRegressor(estimator.Estimator):
         exports_to_keep=exports_to_keep)
 
 
-# TODO(zakaria): Make it public when b/34751732 is fixed.
-class _LinearEstimator(estimator.Estimator):
+class LinearEstimator(estimator.Estimator):
   """Linear model with user specified head.
 
   Train a generalized linear model to predict label value given observation of
@@ -840,9 +839,9 @@ class _LinearEstimator(estimator.Estimator):
 
   sparse_feature_a_x_sparse_feature_b = crossed_column(...)
 
-  estimator = _LinearEstimator(
+  estimator = LinearEstimator(
       feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
-      head=head_lib._poisson_regression_head())
+      head=head_lib.poisson_regression_head())
 
   # Input builders
   def input_fn_train: # returns x, y
@@ -879,7 +878,7 @@ class _LinearEstimator(estimator.Estimator):
                _joint_weights=False,
                config=None,
                feature_engineering_fn=None):
-    """Construct a `_LinearEstimator` object.
+    """Construct a `LinearEstimator` object.
 
     Args:
       feature_columns: An iterable containing all the feature columns used by
@@ -907,14 +906,14 @@ class _LinearEstimator(estimator.Estimator):
         into the model.
 
     Returns:
-      A `_LinearEstimator` estimator.
+      A `LinearEstimator` estimator.
 
     Raises:
       ValueError: if optimizer is not supported, e.g., SDCAOptimizer
     """
     assert feature_columns
     if isinstance(optimizer, sdca_optimizer.SDCAOptimizer):
-      raise ValueError("_LinearEstimator does not support SDCA optimizer.")
+      raise ValueError("LinearEstimator does not support SDCA optimizer.")
 
     params = {
         "head": head,
@@ -923,7 +922,7 @@ class _LinearEstimator(estimator.Estimator):
         "gradient_clip_norm": gradient_clip_norm,
         "joint_weights": _joint_weights,
     }
-    super(_LinearEstimator, self).__init__(
+    super(LinearEstimator, self).__init__(
         model_fn=_linear_model_fn,
         model_dir=model_dir,
         config=config,
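
With the TODO resolved, the class is renamed `_LinearEstimator` → `LinearEstimator` and its docstring example now reads as user code against the public head factories. A minimal sketch of that pairing, assuming `input_fn_train` and the two feature columns are defined as in the docstring above (they are placeholders there as well):

    from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
    from tensorflow.contrib.learn.python.learn.estimators import linear

    # A linear model whose loss and predictions are chosen by the head; any
    # other head factory could be swapped in. SDCAOptimizer is still rejected.
    estimator = linear.LinearEstimator(
        feature_columns=[sparse_column_a, sparse_feature_a_x_sparse_feature_b],
        head=head_lib.poisson_regression_head())
    estimator.fit(input_fn=input_fn_train, steps=100)
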
@@ -1665,15 +1665,15 @@ class LinearEstimatorTest(test.TestCase):
             'feature', dimension=4)
     ]
     exp = experiment.Experiment(
-        estimator=linear._LinearEstimator(feature_columns=cont_features,
-                                          head=head_lib._regression_head()),
+        estimator=linear.LinearEstimator(feature_columns=cont_features,
+                                         head=head_lib.regression_head()),
         train_input_fn=test_data.iris_input_logistic_fn,
         eval_input_fn=test_data.iris_input_logistic_fn)
     exp.test()
 
   def testEstimatorContract(self):
     estimator_test_utils.assert_estimator_contract(self,
-                                                   linear._LinearEstimator)
+                                                   linear.LinearEstimator)
 
   def testLinearRegression(self):
     """Tests that loss goes down with training."""
@@ -1691,8 +1691,8 @@ class LinearEstimatorTest(test.TestCase):
         100)
     age = feature_column_lib.real_valued_column('age')
 
-    linear_estimator = linear._LinearEstimator(feature_columns=[age, language],
-                                               head=head_lib._regression_head())
+    linear_estimator = linear.LinearEstimator(feature_columns=[age, language],
+                                              head=head_lib.regression_head())
     linear_estimator.fit(input_fn=input_fn, steps=100)
     loss1 = linear_estimator.evaluate(input_fn=input_fn, steps=1)['loss']
     linear_estimator.fit(input_fn=input_fn, steps=400)
@@ -1717,9 +1717,9 @@ class LinearEstimatorTest(test.TestCase):
         100)
     age = feature_column_lib.real_valued_column('age')
 
-    linear_estimator = linear._LinearEstimator(
+    linear_estimator = linear.LinearEstimator(
         feature_columns=[age, language],
-        head=head_lib._poisson_regression_head())
+        head=head_lib.poisson_regression_head())
     linear_estimator.fit(input_fn=input_fn, steps=10)
     loss1 = linear_estimator.evaluate(input_fn=input_fn, steps=1)['loss']
     linear_estimator.fit(input_fn=input_fn, steps=100)
@@ -1736,8 +1736,8 @@ class LinearEstimatorTest(test.TestCase):
     sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
         example_id_column='example_id')
     with self.assertRaises(ValueError):
-      linear._LinearEstimator(
-          head=head_lib._regression_head(label_dimension=1),
+      linear.LinearEstimator(
+          head=head_lib.regression_head(label_dimension=1),
           feature_columns=[maintenance_cost, sq_footage],
           optimizer=sdca_optimizer,
           _joint_weights=True)
@@ -139,7 +139,7 @@ class SVM(estimator.Estimator):
         model_dir=model_dir,
         config=config,
         params={
-            "head": head_lib._binary_svm_head(  # pylint: disable=protected-access
+            "head": head_lib.binary_svm_head(
                 weight_column_name=weight_column_name,
                 enable_centered_bias=False),
             "feature_columns": feature_columns,
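
The SVM estimator gets the same treatment: its head is now built through the public `binary_svm_head`, which also lets the `# pylint: disable=protected-access` suppression go away. For code that drives a `Head` by hand rather than through an estimator, the calling pattern from the tests earlier in this CL is unchanged apart from the name; a sketch, where `features`, `labels`, `logits`, and `train_op_fn` stand in for the caller's own tensors and callback (and `weight_column_name` likewise):

    # Build the public SVM head and let it assemble the ModelFnOps for a
    # training step from the caller-supplied pieces.
    head = head_lib.binary_svm_head(
        weight_column_name=weight_column_name,
        enable_centered_bias=False)
    model_fn_ops = head.create_model_fn_ops(
        features, model_fn.ModeKeys.TRAIN, labels, train_op_fn, logits=logits)
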