Fixes a bug in setting default optimizers for DNNLinearCombinedClassifier.

PiperOrigin-RevId: 158190192
A. Unique TensorFlower 2017-06-06 14:27:21 -07:00 committed by TensorFlower Gardener
parent 3ca6533049
commit a4e7b7add4


@@ -300,9 +300,9 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
   def __init__(self,
                model_dir=None,
                linear_feature_columns=None,
-               linear_optimizer=None,
+               linear_optimizer='Ftrl',
                dnn_feature_columns=None,
-               dnn_optimizer=None,
+               dnn_optimizer='Adagrad',
                dnn_hidden_units=None,
                dnn_activation_fn=nn.relu,
                dnn_dropout=None,
@@ -319,12 +319,12 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
         used by linear part of the model. All items in the set must be
         instances of classes derived from `FeatureColumn`.
       linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to
-        the linear part of the model. If `None`, will use a FTRL optimizer.
+        the linear part of the model. Defaults to FTRL optimizer.
       dnn_feature_columns: An iterable containing all the feature columns used
        by deep part of the model. All items in the set must be instances of
         classes derived from `FeatureColumn`.
       dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to
-        the deep part of the model. If `None`, will use an Adagrad optimizer.
+        the deep part of the model. Defaults to Adagrad optimizer.
       dnn_hidden_units: List of hidden units per layer. All layers are fully
         connected.
       dnn_activation_fn: Activation function applied to each layer. If None,
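
For context, a minimal usage sketch (not part of this change): with the new string defaults, constructing the canned estimator without optimizer arguments is equivalent to passing linear_optimizer='Ftrl' and dnn_optimizer='Adagrad' explicitly. The feature columns, model_dir, and the tf.estimator export path shown here are assumptions for illustration only.

import tensorflow as tf

# Hypothetical feature columns, for illustration only.
price = tf.feature_column.numeric_column('price')
color = tf.feature_column.indicator_column(
    tf.feature_column.categorical_column_with_vocabulary_list(
        'color', ['red', 'green', 'blue']))

# Omitting the optimizer arguments picks up the string defaults
# ('Ftrl' for the linear part, 'Adagrad' for the deep part) rather than None.
estimator = tf.estimator.DNNLinearCombinedClassifier(
    model_dir='/tmp/wide_and_deep_model',  # hypothetical directory
    linear_feature_columns=[color],
    dnn_feature_columns=[price],
    dnn_hidden_units=[64, 32])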