Update docstring for BoostedTrees n_batches_per_layer.

PiperOrigin-RevId: 211824645
This commit is contained in:
Zhenyu Tan 2018-09-06 10:02:24 -07:00 committed by TensorFlower Gardener
parent d17016a8df
commit 43a3c393d7

View File

@ -1000,8 +1000,11 @@ class BoostedTreesClassifier(estimator.Estimator):
bucketized_feature_2 = bucketized_column( bucketized_feature_2 = bucketized_column(
numeric_column('feature_2'), BUCKET_BOUNDARIES_2) numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
# Need to see a large portion of the data before we can build a layer, for
# example half of the data.
n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE
classifier = estimator.BoostedTreesClassifier( classifier = estimator.BoostedTreesClassifier(
feature_columns=[bucketized_feature_1, bucketized_feature_2], feature_columns=[bucketized_feature_1, bucketized_feature_2],
n_batches_per_layer=n_batches_per_layer,
n_trees=100, n_trees=100,
... <some other params> ... <some other params>
) )
@ -1024,7 +1027,8 @@ class BoostedTreesClassifier(estimator.Estimator):
the model. All items in the set should be instances of classes derived the model. All items in the set should be instances of classes derived
from `FeatureColumn`. from `FeatureColumn`.
n_batches_per_layer: the number of batches to collect statistics per n_batches_per_layer: the number of batches to collect statistics per
layer. layer. The total number of batches is the total number of examples
divided by the batch size.
model_dir: Directory to save model parameters, graph, etc. This can model_dir: Directory to save model parameters, graph, etc. This can
also be used to load checkpoints from the directory into an estimator also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model. to continue training a previously saved model.
@ -1138,8 +1142,11 @@ class BoostedTreesRegressor(estimator.Estimator):
bucketized_feature_2 = bucketized_column( bucketized_feature_2 = bucketized_column(
numeric_column('feature_2'), BUCKET_BOUNDARIES_2) numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
# Need to see a large portion of the data before we can build a layer, for
# example half of the data.
n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE
regressor = estimator.BoostedTreesRegressor( regressor = estimator.BoostedTreesRegressor(
feature_columns=[bucketized_feature_1, bucketized_feature_2], feature_columns=[bucketized_feature_1, bucketized_feature_2],
n_batches_per_layer=n_batches_per_layer,
n_trees=100, n_trees=100,
... <some other params> ... <some other params>
) )
@ -1162,7 +1169,8 @@ class BoostedTreesRegressor(estimator.Estimator):
the model. All items in the set should be instances of classes derived the model. All items in the set should be instances of classes derived
from `FeatureColumn`. from `FeatureColumn`.
n_batches_per_layer: the number of batches to collect statistics per n_batches_per_layer: the number of batches to collect statistics per
layer. layer. The total number of batches is the total number of examples
divided by the batch size.
model_dir: Directory to save model parameters, graph, etc. This can model_dir: Directory to save model parameters, graph, etc. This can
also be used to load checkpoints from the directory into an estimator also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model. to continue training a previously saved model.