From 0ac688a18cc56816d8c767f7fcbce97b05b2319e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 9 Oct 2017 13:21:22 -0700 Subject: [PATCH] Adding a binary classification example PiperOrigin-RevId: 171577979 --- tensorflow/contrib/boosted_trees/README.md | 11 ++ .../boosted_trees/examples/binary_mnist.py | 169 ++++++++++++++++++ .../contrib/boosted_trees/examples/boston.py | 2 - 3 files changed, 180 insertions(+), 2 deletions(-) create mode 100644 tensorflow/contrib/boosted_trees/README.md create mode 100644 tensorflow/contrib/boosted_trees/examples/binary_mnist.py diff --git a/tensorflow/contrib/boosted_trees/README.md b/tensorflow/contrib/boosted_trees/README.md new file mode 100644 index 00000000000..9ce700f1a19 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/README.md @@ -0,0 +1,11 @@ +# TF Boosted Trees (TFBT) + +TF Boosted trees is an implementation of a gradient boosting algorithm with +trees used as week learners. + +## Examples +Folder "examples" demonstrates how TFBT estimators can be used for various +problems. Namely, it contains: +* binary_mnist.py - an example on how to use TFBT for binary classification. +* mnist.py - a multiclass example. +* boston.py - a regression example. \ No newline at end of file diff --git a/tensorflow/contrib/boosted_trees/examples/binary_mnist.py b/tensorflow/contrib/boosted_trees/examples/binary_mnist.py new file mode 100644 index 00000000000..9be362f5c83 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/binary_mnist.py @@ -0,0 +1,169 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Demonstrates multiclass MNIST TF Boosted trees example. + + This example demonstrates how to run experiments with TF Boosted Trees on + a binary dataset. We use digits 4 and 9 from the original MNIST dataset. + + Example Usage: + python tensorflow/contrib/boosted_trees/examples/binary_mnist.py \ + --output_dir="/tmp/binary_mnist" --depth=4 --learning_rate=0.3 \ + --batch_size=10761 --examples_per_layer=10761 --eval_batch_size=1030 \ + --num_eval_steps=1 --num_trees=10 --l2=1 --vmodule=training_ops=1 \ + + When training is done, accuracy on eval data is reported. Point tensorboard + to the directory for the run to see how the training progresses: + + tensorboard --logdir=/tmp/binary_mnist + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.learn import learn_runner + + +def get_input_fn(data, + batch_size, + capacity=10000, + min_after_dequeue=3000): + """Input function over MNIST data.""" + # Keep only 4 and 9 digits. + ids = np.where((data.labels == 4) | (data.labels == 9)) + images = data.images[ids] + labels = data.labels[ids] + # Make digit 4 label 0, 9 is 1. + labels = labels == 4 + + def _input_fn(): + """Prepare features and labels.""" + images_batch, labels_batch = tf.train.shuffle_batch( + tensors=[images, + labels.astype(np.int32)], + batch_size=batch_size, + capacity=capacity, + min_after_dequeue=min_after_dequeue, + enqueue_many=True, + num_threads=4) + features_map = {"images": images_batch} + return features_map, labels_batch + + return _input_fn + + +# Main config - creates a TF Boosted Trees Estimator based on flags. +def _get_tfbt(output_dir): + """Configures TF Boosted Trees estimator based on flags.""" + learner_config = learner_pb2.LearnerConfig() + + learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate + learner_config.regularization.l1 = 0.0 + learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer + learner_config.constraints.max_tree_depth = FLAGS.depth + + growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER + learner_config.growing_mode = growing_mode + run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) + + # Create a TF Boosted trees estimator that can take in custom loss. + estimator = GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + examples_per_layer=FLAGS.examples_per_layer, + model_dir=output_dir, + num_trees=FLAGS.num_trees, + center_bias=False, + config=run_config) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for gradient boosted decision trees.""" + data = tf.contrib.learn.datasets.mnist.load_mnist() + train_input_fn = get_input_fn(data.train, FLAGS.batch_size) + eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size) + + return tf.contrib.learn.Experiment( + estimator=_get_tfbt(output_dir), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--eval_batch_size", + type=int, + default=1000, + help="Size of the batch for eval.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for gradient boosted trees config. + parser.add_argument( + "--depth", type=int, default=4, help="Maximum depth of weak learners.") + parser.add_argument( + "--l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--learning_rate", + type=float, + default=0.1, + help="Learning rate (shrinkage weight) with which each new tree is added." + ) + parser.add_argument( + "--examples_per_layer", + type=int, + default=1000, + help="Number of examples to accumulate stats for per layer.") + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py index 0cb9e956ef6..2c0a3c4912b 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston.py +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -44,8 +44,6 @@ from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn import learn_runner -_TEST_SPLIT_RATIO = 0.2 -_TEST_SPLIT_SEED = 42 _BOSTON_NUM_FEATURES = 13