From 43a3c393d7a329b7dc7aec02a7d46dc69e5a8ee1 Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Thu, 6 Sep 2018 10:02:24 -0700 Subject: [PATCH] Update docstring for BoostedTrees n_batches_per_layer. PiperOrigin-RevId: 211824645 --- tensorflow/python/estimator/canned/boosted_trees.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index d104c961d3f..19f18015e42 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -1000,8 +1000,11 @@ class BoostedTreesClassifier(estimator.Estimator): bucketized_feature_2 = bucketized_column( numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + # Need to see a large portion of the data before we can build a layer, for + # example, half of the data n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE classifier = estimator.BoostedTreesClassifier( feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_batches_per_layer=n_batches_per_layer, n_trees=100, ... ) @@ -1024,7 +1027,8 @@ class BoostedTreesClassifier(estimator.Estimator): the model. All items in the set should be instances of classes derived from `FeatureColumn`. n_batches_per_layer: the number of batches to collect statistics per - layer. + layer. The total number of batches is the total number of examples + divided by the batch size. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. 
@@ -1138,8 +1142,11 @@ class BoostedTreesRegressor(estimator.Estimator): bucketized_feature_2 = bucketized_column( numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + # Need to see a large portion of the data before we can build a layer, for + # example, half of the data n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE regressor = estimator.BoostedTreesRegressor( feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_batches_per_layer=n_batches_per_layer, n_trees=100, ... ) @@ -1162,7 +1169,8 @@ class BoostedTreesRegressor(estimator.Estimator): the model. All items in the set should be instances of classes derived from `FeatureColumn`. n_batches_per_layer: the number of batches to collect statistics per - layer. + layer. The total number of batches is the total number of examples + divided by the batch size. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model.