From c2d9cb1d0856aae2ee1d37cb4e23cddbb8a88c79 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 18 Apr 2016 08:26:31 -0800
Subject: [PATCH] Fixes and enhancements for contrib/tensorforest:

- remove flags (pushed to internal client)
- thread-parallel execution for CountExtremelyRandomStats op.
- critical time-seeding fix.

Change: 120129936
---
 .../ops/count_extremely_random_stats_op.cc    | 150 +++++++++++++-----
 .../core/ops/sample_inputs_op.cc              |   2 +-
 .../count_extremely_random_stats_op_test.py   |  25 ++-
 .../tensor_forest/python/ops/inference_ops.py |  10 +-
 .../tensor_forest/python/ops/training_ops.py  |   9 +-
 .../tensor_forest/python/tensor_forest.py     |  80 ++++------
 6 files changed, 170 insertions(+), 106 deletions(-)

diff --git a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
index bd2cd59eea8..2f1602d0ba5 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc
@@ -18,6 +18,7 @@
 // only op that involves tree traversal, and is constructed so that it can
 // be run in parallel on separate batches of data.
 #include <unordered_map>
+#include <memory>
 
 #include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
 
@@ -25,10 +26,12 @@
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/util/work_sharder.h"
 
 namespace tensorflow {
 
 using std::get;
+using std::make_pair;
 using std::make_tuple;
 using std::pair;
 using std::tuple;
@@ -42,6 +45,71 @@
 using tensorforest::DecideNode;
 using tensorforest::Initialize;
 using tensorforest::IsAllInitialized;
 
+// A data structure to store the results of parallel tree traversal.
+struct InputDataResult {
+  // A list of each node that was visited.
+  std::vector<int32> node_indices;
+  // The accumulator of the leaf that a data point ended up at, or -1 if none.
+  int32 leaf_accumulator;
+  // The candidate splits for which the left branch was taken.
+  std::vector<int32> split_adds;
+  // If the candidate splits for the leaf that a data point arrived at
+  // were initialized or not, which determines if we add this to total
+  // pcw counts or not.
+  bool splits_initialized;
+};
+
+void Evaluate(const Tensor& input_data, const Tensor& input_labels,
+              const Tensor& tree_tensor, const Tensor& tree_thresholds,
+              const Tensor& node_to_accumulator,
+              const Tensor& candidate_split_features,
+              const Tensor& candidate_split_thresholds,
+              InputDataResult* results, int64 start, int64 end) {
+  const auto tree = tree_tensor.tensor<int32, 2>();
+  const auto thresholds = tree_thresholds.unaligned_flat<float>();
+  const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+  const auto split_features = candidate_split_features.tensor<int32, 2>();
+  const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
+
+  const int32 num_splits = candidate_split_features.shape().dim_size(1);
+
+  for (int i = start; i < end; ++i) {
+    const Tensor point = input_data.Slice(i, i + 1);
+    int node_index = 0;
+    results[i].splits_initialized = false;
+    while (true) {
+      results[i].node_indices.push_back(node_index);
+      int32 left_child = tree(node_index, CHILDREN_INDEX);
+      if (left_child == LEAF_NODE) {
+        const int32 accumulator = node_map(node_index);
+        results[i].leaf_accumulator = accumulator;
+        // If the leaf is not fertile or is not yet initialized, we don't
+        // count it in the candidate/total split per-class-weights because
+        // it won't have any candidate splits yet.
+        if (accumulator >= 0 &&
+            IsAllInitialized(candidate_split_features.Slice(
+                accumulator, accumulator + 1))) {
+          results[i].splits_initialized = true;
+          for (int split = 0; split < num_splits; split++) {
+            if (!DecideNode(point, split_features(accumulator, split),
+                            split_thresholds(accumulator, split))) {
+              results[i].split_adds.push_back(split);
+            }
+          }
+        }
+        break;
+      } else if (left_child == FREE_NODE) {
+        LOG(ERROR) << "Reached a free node, not good.";
+        results[i].node_indices.push_back(FREE_NODE);
+        break;
+      }
+      node_index =
+          left_child + DecideNode(point, tree(node_index, FEATURE_INDEX),
+                                  thresholds(node_index));
+    }
+  }
+}
+
 REGISTER_OP("CountExtremelyRandomStats")
     .Attr("num_classes: int32")
     .Input("input_data: float")
@@ -176,7 +244,31 @@ class CountExtremelyRandomStats : public OpKernel {
                     "candidate_split_features and candidate_split_thresholds should be "
                     "the same shape."));
 
-    const int32 num_splits = candidate_split_features.shape().dim_size(1);
+    // Evaluate input data in parallel.
+    const int64 num_data = input_data.shape().dim_size(0);
+    std::unique_ptr<InputDataResult[]> results(new InputDataResult[num_data]);
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    int num_threads = worker_threads->num_threads;
+    if (num_threads <= 1) {
+      Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
+               node_to_accumulator, candidate_split_features,
+               candidate_split_thresholds, results.get(), 0, num_data);
+    } else {
+      auto work = [&input_data, &input_labels, &tree_tensor, &tree_thresholds,
+                   &node_to_accumulator, &candidate_split_features,
+                   &candidate_split_thresholds, &num_data,
+                   &results](int64 start, int64 end) {
+        CHECK(start <= end);
+        CHECK(end <= num_data);
+        Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
+                 node_to_accumulator, candidate_split_features,
+                 candidate_split_thresholds, results.get(), start, end);
+      };
+      Shard(num_threads, worker_threads->workers, num_data, 100, work);
+    }
+
+    // Set output tensors.
+    const auto labels = input_labels.unaligned_flat<int32>();
 
     // node pcw delta
     Tensor* output_node_pcw_delta = nullptr;
@@ -196,58 +288,28 @@
                                                      &output_leaves));
     auto out_leaves = output_leaves->unaligned_flat<int32>();
 
-    const auto tree = tree_tensor.tensor<int32, 2>();
-    const auto thresholds = tree_thresholds.unaligned_flat<float>();
-    const auto labels = input_labels.unaligned_flat<int32>();
-    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
-    const auto split_features = candidate_split_features.tensor<int32, 2>();
-    const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
-
-    const int32 num_data = input_data.shape().dim_size(0);
-
     // <accumulator, class> -> count delta
     std::unordered_map<pair<int32, int32>, int32, PairIntHash> total_delta;
     // <accumulator, split, class> -> count delta
     std::unordered_map<tuple<int32, int32, int32>, int32, TupleIntHash> split_delta;
-    for (int i = 0; i < num_data; i++) {
-      const Tensor point = input_data.Slice(i, i+1);
-      int node_index = 0;
-      while (true) {
-        const int32 label = labels(i);
-        ++out_node(node_index, label);
-        int32 left_child = tree(node_index, CHILDREN_INDEX);
-        if (left_child == LEAF_NODE) {
-          out_leaves(i) = node_index;
-          const int32 accumulator = node_map(node_index);
-          // If the leaf is not fertile or is not yet initialized, we don't
-          // count it in the candidate/total split per-class-weights because
-          // it won't have any candidate splits yet.
-          if (accumulator >= 0 &&
-              IsAllInitialized(
-                  candidate_split_features.Slice(accumulator,
-                                                 accumulator + 1))) {
-            ++total_delta[std::make_pair(accumulator, label)];
-            for (int split = 0; split < num_splits; split++) {
-              if (!DecideNode(point, split_features(accumulator, split),
-                              split_thresholds(accumulator, split))) {
-                ++split_delta[make_tuple(accumulator, split, label)];
-              }
-            }
-          }
-          break;
-        } else if (left_child == FREE_NODE) {
-          LOG(ERROR) << "Reached a free node, not good.";
-          out_leaves(i) = FREE_NODE;
-          break;
+
+    for (int32 i = 0; i < num_data; ++i) {
+      const int32 label = labels(i);
+      const int32 accumulator = results[i].leaf_accumulator;
+      for (const int32 node : results[i].node_indices) {
+        ++out_node(node, label);
+      }
+      out_leaves(i) = results[i].node_indices.back();
+      if (accumulator >= 0 && results[i].splits_initialized) {
+        ++total_delta[make_pair(accumulator, label)];
+        for (const int32 split : results[i].split_adds) {
+          ++split_delta[make_tuple(accumulator, split, label)];
         }
-        node_index = left_child +
-            DecideNode(point, tree(node_index, FEATURE_INDEX),
-                       thresholds(node_index));
       }
     }
 
-    // candidate splits pcw indices
+    // candidate splits pcw indices
     Tensor* output_candidate_pcw_indices = nullptr;
     TensorShape candidate_pcw_shape;
     candidate_pcw_shape.AddDim(split_delta.size());
diff --git a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
index 452a68353c8..90ed0420e3a 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc
@@ -94,7 +94,7 @@ class SampleInputs : public OpKernel {
                    "split_sampling_random_seed", &split_sampling_random_seed_));
 
     // Set up the random number generator.
     if (split_sampling_random_seed_ == 0) {
-      uint64 time_seed = static_cast<uint64>(std::time(NULL));
+      uint64 time_seed = static_cast<uint64>(std::clock());
       single_rand_ = std::unique_ptr<random::PhiloxRandom>(
           new random::PhiloxRandom(time_seed));
     } else {
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
index e93bb17522a..5f2a2e49ef6 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py
@@ -17,7 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow  # pylint: disable=unused-import
+import tensorflow as tf
 
 from tensorflow.contrib.tensor_forest.python.ops import training_ops
 
@@ -47,6 +47,29 @@ class CountExtremelyRandomStatsTest(test_util.TensorFlowTestCase):
                            self.tree_thresholds, self.node_map,
                            self.split_features, self.split_thresholds,
                            num_classes=4))
+      self.assertAllEqual(
+          [[1., 1., 1., 1.], [1., 1., 0., 0.], [0., 0., 1., 1.]],
+          pcw_node.eval())
+      self.assertAllEqual([[0, 0, 0]], pcw_splits_indices.eval())
+      self.assertAllEqual([1.], pcw_splits_delta.eval())
+      self.assertAllEqual([[0, 1], [0, 0]], pcw_totals_indices.eval())
+      self.assertAllEqual([1., 1.], pcw_totals_delta.eval())
+      self.assertAllEqual([1, 1, 2, 2], leaves.eval())
+
+  def testThreaded(self):
+    with self.test_session(
+        config=tf.ConfigProto(intra_op_parallelism_threads=2)):
+      (pcw_node, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
+       pcw_totals_delta,
+       leaves) = (self.ops.count_extremely_random_stats(self.input_data,
+                                                        self.input_labels,
+                                                        self.tree,
+                                                        self.tree_thresholds,
+                                                        self.node_map,
+                                                        self.split_features,
+                                                        self.split_thresholds,
+                                                        num_classes=4))
+
       self.assertAllEqual([[1., 1., 1., 1.], [1., 1., 0., 0.], [0., 0., 1., 1.]],
                           pcw_node.eval())
 
diff --git a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
index bcf6ca6b6ea..62add1bf6ce 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py
@@ -25,12 +25,6 @@ import tensorflow as tf
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string('inference_library_base_dir', '',
-                    'Directory to look for inference library file.')
-
 INFERENCE_OPS_FILE = '_inference_ops.so'
 
 _inference_ops = None
@@ -55,12 +49,12 @@ def TreePredictions(op):
 # there's not yet any guarantee that the shared object exists.
 # In which case, "import tensorflow" will always crash, even for users that
 # never use contrib.
-def Load():
+def Load(library_base_dir=''):
   """Load the inference ops library and return the loaded module."""
   with _ops_lock:
     global _inference_ops
     if not _inference_ops:
-      data_files_path = os.path.join(FLAGS.inference_library_base_dir,
+      data_files_path = os.path.join(library_base_dir,
                                      tf.resource_loader.get_data_files_path())
       tf.logging.info('data path: %s', data_files_path)
       _inference_ops = tf.load_op_library(os.path.join(
diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
index 5cf5e4af908..84bc2cfea6f 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py
@@ -25,11 +25,6 @@ import tensorflow as tf
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string('training_library_base_dir', '',
-                    'Directory to look for inference library file.')
 
 TRAINING_OPS_FILE = '_training_ops.so'
 
@@ -102,12 +97,12 @@ def _UpdateFertileSlotsShape(unused_op):
 # there's not yet any guarantee that the shared object exists.
 # In which case, "import tensorflow" will always crash, even for users that
 # never use contrib.
-def Load():
+def Load(library_base_dir=''):
   """Load training ops library and return the loaded module."""
   with _ops_lock:
     global _training_ops
     if not _training_ops:
-      data_files_path = os.path.join(FLAGS.training_library_base_dir,
+      data_files_path = os.path.join(library_base_dir,
                                      tf.resource_loader.get_data_files_path())
       tf.logging.info('data path: %s', data_files_path)
       _training_ops = tf.load_op_library(os.path.join(
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index 45e8cab4857..3d254f2d505 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -25,26 +25,6 @@
 from tensorflow.contrib.tensor_forest.python.ops import inference_ops
 from tensorflow.contrib.tensor_forest.python.ops import training_ops
 
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-
-
-# Default parameter values.  These are all only used if the corresponding
-# parameter is not specified when constructing the ForestHParams.
-flags.DEFINE_integer('num_trees', 100, 'Number of trees in forest')
-flags.DEFINE_integer('max_nodes', 10000, 'Maxmimum number of tree nodes.')
-flags.DEFINE_float(
-    'samples_to_decide', 25.0,
-    'Only decide on a split, or only fully use a leaf, after this many '
-    'training samples have been seen.')
-flags.DEFINE_float('bagging_fraction', 1.0,
-                   'Use this fraction of the input, randomly chosen, to train '
-                   'each tree in the forest.')
-flags.DEFINE_integer(
-    'num_splits_to_consider', 0,
-    'If non-zero, consider this many candidates for a splitting '
-    'rule at a fertile node.')
-
 # If tree[i][0] equals this value, then i is a leaf node.
 LEAF_NODE = -1
 
@@ -64,7 +44,20 @@ LEAF_NODE = -1
 class ForestHParams(object):
   """A base class for holding hyperparameters and calculating good defaults."""
 
-  def __init__(self, **kwargs):
+  def __init__(self, num_trees=100, max_nodes=10000, bagging_fraction=1.0,
+               samples_to_decide=25, max_depth=0, num_splits_to_consider=0,
+               max_fertile_nodes=0, split_after_samples=0,
+               valid_leaf_threshold=0, **kwargs):
+    self.num_trees = num_trees
+    self.max_nodes = max_nodes
+    self.bagging_fraction = bagging_fraction
+    self.samples_to_decide = samples_to_decide
+    self.max_depth = max_depth
+    self.num_splits_to_consider = num_splits_to_consider
+    self.max_fertile_nodes = max_fertile_nodes
+    self.split_after_samples = split_after_samples
+    self.valid_leaf_threshold = valid_leaf_threshold
+
     for name, value in kwargs.items():
       setattr(self, name, value)
 
@@ -76,24 +69,21 @@ class ForestHParams(object):
     # Fail fast if num_classes isn't set.
     _ = getattr(self, 'num_classes')
 
-    self.bagging_fraction = getattr(self, 'bagging_fraction',
-                                    FLAGS.bagging_fraction)
-
-    self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees)
-    self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes)
+    self.training_library_base_dir = getattr(
+        self, 'training_library_base_dir', '')
+    self.inference_library_base_dir = getattr(
+        self, 'inference_library_base_dir', '')
 
     # Allow each tree to be unbalanced by up to a factor of 2.
-    self.max_depth = getattr(self, 'max_depth',
-                             int(2 * math.ceil(math.log(self.max_nodes, 2))))
+    self.max_depth = (self.max_depth or
+                      int(2 * math.ceil(math.log(self.max_nodes, 2))))
 
     # The Random Forest literature recommends sqrt(# features) for
     # classification problems, and p/3 for regression problems.
     # TODO(thomaswc): Consider capping this for large number of features.
-    self.num_splits_to_consider = getattr(self, 'num_splits_to_consider',
-                                          FLAGS.num_splits_to_consider)
-    if not self.num_splits_to_consider:
-      self.num_splits_to_consider = max(10, int(
-          math.ceil(math.sqrt(self.num_features))))
+    self.num_splits_to_consider = (
+        self.num_splits_to_consider or
+        max(10, int(math.ceil(math.sqrt(self.num_features)))))
 
     # max_fertile_nodes doesn't effect performance, only training speed.
     # We therefore set it primarily based upon space considerations.
@@ -103,7 +93,7 @@
     num_fertile = int(math.ceil(self.max_nodes / self.num_splits_to_consider))
     # But always use at least 1000 accumulate slots.
     num_fertile = max(num_fertile, 1000)
-    self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile)
+    self.max_fertile_nodes = self.max_fertile_nodes or num_fertile
     # But it also never needs to be larger than the number of leaves,
     # which is max_nodes / 2.
     self.max_fertile_nodes = min(self.max_fertile_nodes,
@@ -111,14 +101,11 @@
     # split_after_samples and valid_leaf_threshold should be about the same.
     # Therefore, if either is set, use it to set the other.  Otherwise, fall
-    # back on FLAGS.samples_to_decide.
-    samples_to_decide = (
-        getattr(self, 'split_after_samples',
-                getattr(self, 'valid_leaf_threshold', FLAGS.samples_to_decide)))
-    self.split_after_samples = getattr(self, 'split_after_samples',
-                                       samples_to_decide)
-    self.valid_leaf_threshold = getattr(self, 'valid_leaf_threshold',
-                                        samples_to_decide)
+    # back on samples_to_decide.
+    samples_to_decide = self.split_after_samples or self.samples_to_decide
+
+    self.split_after_samples = self.split_after_samples or samples_to_decide
+    self.valid_leaf_threshold = self.valid_leaf_threshold or samples_to_decide
 
     # We have num_splits_to_consider slots to fill, and we want to spend
     # approximately split_after_samples samples initializing them.
@@ -249,9 +236,12 @@ class RandomForestGraphs(object):
     tf.logging.info(self.params.__dict__)
     self.variables = variables or ForestTrainingVariables(
        self.params, device_assigner=self.device_assigner)
-    self.trees = [RandomTreeGraphs(self.variables[i], self.params,
-                                   training_ops.Load(), inference_ops.Load())
-                  for i in range(self.params.num_trees)]
+    self.trees = [
+        RandomTreeGraphs(
+            self.variables[i], self.params,
+            training_ops.Load(self.params.training_library_base_dir),
+            inference_ops.Load(self.params.inference_library_base_dir))
+        for i in range(self.params.num_trees)]
 
   def training_graph(self, input_data, input_labels):
     """Constructs a TF graph for training a random forest.
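
Usage note: with the flags removed, the former FLAGS values become ordinary
ForestHParams keyword arguments, and the op-library location is threaded
through RandomForestGraphs to training_ops.Load() / inference_ops.Load().
The seeding change matters because std::time(NULL) has one-second
resolution, so kernels constructed within the same second previously drew
identical PhiloxRandom seeds; std::clock() ticks far more often. Below is a
minimal sketch of the resulting Python API. It assumes the
defaults-computing method on ForestHParams is named fill() and returns
self, and that RandomForestGraphs takes params as its first constructor
argument; both definitions lie outside the lines shown in this diff.

    from tensorflow.contrib.tensor_forest.python import tensor_forest

    # num_classes is required; the defaults-computing method fails fast
    # without it. num_features is needed to derive num_splits_to_consider.
    params = tensor_forest.ForestHParams(
        num_classes=4,
        num_features=10,
        num_trees=10,                   # formerly FLAGS.num_trees
        max_nodes=1000,                 # formerly FLAGS.max_nodes
        training_library_base_dir='',   # formerly FLAGS.training_library_base_dir
        inference_library_base_dir='')  # formerly FLAGS.inference_library_base_dir

    # fill() (assumed name) computes the derived defaults: max_depth,
    # num_splits_to_consider, max_fertile_nodes, split_after_samples, etc.
    params = params.fill()

    # The graph builder passes the base dirs through to the Load() calls,
    # so no global flag state is consulted at graph-construction time.
    graphs = tensor_forest.RandomForestGraphs(params)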