Fixes and enhancements for contrib/tensorforest:

- remove flags (pushed to internal client)
- thread-parallel execution for the CountExtremelyRandomStats op.
- critical time-seeding fix.
Change: 120129936
A. Unique TensorFlower 2016-04-18 08:26:31 -08:00 committed by TensorFlower Gardener
parent 449ecb561f
commit c2d9cb1d08
6 changed files with 170 additions and 106 deletions


@@ -18,6 +18,7 @@
// only op that involves tree traversal, and is constructed so that it can
// be run in parallel on separate batches of data.
#include <unordered_map>
#include <vector>
#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
@@ -25,10 +26,12 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
using std::get;
using std::make_pair;
using std::make_tuple;
using std::pair;
using std::tuple;
@@ -42,6 +45,71 @@ using tensorforest::DecideNode;
using tensorforest::Initialize;
using tensorforest::IsAllInitialized;
// A data structure to store the results of parallel tree traversal.
struct InputDataResult {
// A list of each node that was visited.
std::vector<int32> node_indices;
// The accumulator of the leaf that a data point ended up at, or -1 if none.
int32 leaf_accumulator;
// The left-branch taken candidate splits.
std::vector<int32> split_adds;
// Whether the candidate splits for the leaf that a data point arrived at
// were initialized, which determines whether this point is added to the
// total pcw counts.
bool splits_initialized;
};
void Evaluate(const Tensor& input_data, const Tensor& input_labels,
const Tensor& tree_tensor, const Tensor& tree_thresholds,
const Tensor& node_to_accumulator,
const Tensor& candidate_split_features,
const Tensor& candidate_split_thresholds,
InputDataResult* results, int64 start, int64 end) {
const auto tree = tree_tensor.tensor<int32, 2>();
const auto thresholds = tree_thresholds.unaligned_flat<float>();
const auto node_map = node_to_accumulator.unaligned_flat<int32>();
const auto split_features = candidate_split_features.tensor<int32, 2>();
const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
const int32 num_splits = candidate_split_features.shape().dim_size(1);
for (int i = start; i < end; ++i) {
const Tensor point = input_data.Slice(i, i + 1);
int node_index = 0;
results[i].splits_initialized = false;
while (true) {
results[i].node_indices.push_back(node_index);
int32 left_child = tree(node_index, CHILDREN_INDEX);
if (left_child == LEAF_NODE) {
const int32 accumulator = node_map(node_index);
results[i].leaf_accumulator = accumulator;
// If the leaf is not fertile or is not yet initialized, we don't
// count it in the candidate/total split per-class-weights because
// it won't have any candidate splits yet.
if (accumulator >= 0 &&
IsAllInitialized(candidate_split_features.Slice(
accumulator, accumulator + 1))) {
results[i].splits_initialized = true;
for (int split = 0; split < num_splits; split++) {
if (!DecideNode(point, split_features(accumulator, split),
split_thresholds(accumulator, split))) {
results[i].split_adds.push_back(split);
}
}
}
break;
} else if (left_child == FREE_NODE) {
LOG(ERROR) << "Reached a free node, not good.";
results[i].node_indices.push_back(FREE_NODE);
break;
}
node_index =
left_child + DecideNode(point, tree(node_index, FEATURE_INDEX),
thresholds(node_index));
}
}
}
REGISTER_OP("CountExtremelyRandomStats")
.Attr("num_classes: int32")
.Input("input_data: float")
@@ -176,7 +244,31 @@ class CountExtremelyRandomStats : public OpKernel {
"candidate_split_features and candidate_split_thresholds should be "
"the same shape."));
const int32 num_splits = candidate_split_features.shape().dim_size(1);
// Evaluate input data in parallel.
const int64 num_data = input_data.shape().dim_size(0);
std::unique_ptr<InputDataResult[]> results(new InputDataResult[num_data]);
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
if (num_threads <= 1) {
Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
node_to_accumulator, candidate_split_features,
candidate_split_thresholds, results.get(), 0, num_data);
} else {
auto work = [&input_data, &input_labels, &tree_tensor, &tree_thresholds,
&node_to_accumulator, &candidate_split_features,
&candidate_split_thresholds, &num_data,
&results](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
node_to_accumulator, candidate_split_features,
candidate_split_thresholds, results.get(), start, end);
};
Shard(num_threads, worker_threads->workers, num_data, 100, work);
}
// Set output tensors.
const auto labels = input_labels.unaligned_flat<int32>();
// node pcw delta
Tensor* output_node_pcw_delta = nullptr;
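For context: Shard() from core/util/work_sharder.h divides [0, num_data) into contiguous blocks and runs the work closure over them on the device's CPU thread pool. The fourth argument (100 here) is a rough estimate of the cost of processing one row, which the sharder uses to pick a block size; the single-threaded branch skips the sharding machinery entirely.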
@@ -196,58 +288,28 @@ class CountExtremelyRandomStats : public OpKernel {
&output_leaves));
auto out_leaves = output_leaves->unaligned_flat<int32>();
- const auto tree = tree_tensor.tensor<int32, 2>();
- const auto thresholds = tree_thresholds.unaligned_flat<float>();
- const auto labels = input_labels.unaligned_flat<int32>();
- const auto node_map = node_to_accumulator.unaligned_flat<int32>();
- const auto split_features = candidate_split_features.tensor<int32, 2>();
- const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
- const int32 num_data = input_data.shape().dim_size(0);
// <accumulator, class> -> count delta
std::unordered_map<pair<int32, int32>, int32, PairIntHash> total_delta;
// <accumulator, split, class> -> count delta
std::unordered_map<tuple<int32, int32, int32>,
int32, TupleIntHash> split_delta;
- for (int i = 0; i < num_data; i++) {
- const Tensor point = input_data.Slice(i, i+1);
- int node_index = 0;
- while (true) {
- const int32 label = labels(i);
- ++out_node(node_index, label);
- int32 left_child = tree(node_index, CHILDREN_INDEX);
- if (left_child == LEAF_NODE) {
- out_leaves(i) = node_index;
- const int32 accumulator = node_map(node_index);
- // If the leaf is not fertile or is not yet initialized, we don't
- // count it in the candidate/total split per-class-weights because
- // it won't have any candidate splits yet.
- if (accumulator >= 0 &&
- IsAllInitialized(
- candidate_split_features.Slice(accumulator,
- accumulator + 1))) {
- ++total_delta[std::make_pair(accumulator, label)];
- for (int split = 0; split < num_splits; split++) {
- if (!DecideNode(point, split_features(accumulator, split),
- split_thresholds(accumulator, split))) {
- ++split_delta[make_tuple(accumulator, split, label)];
- }
- }
- }
- break;
- } else if (left_child == FREE_NODE) {
- LOG(ERROR) << "Reached a free node, not good.";
- out_leaves(i) = FREE_NODE;
- break;
+ for (int32 i = 0; i < num_data; ++i) {
+ const int32 label = labels(i);
+ const int32 accumulator = results[i].leaf_accumulator;
+ for (const int32 node : results[i].node_indices) {
+ ++out_node(node, label);
+ }
+ out_leaves(i) = results[i].node_indices.back();
+ if (accumulator >= 0 && results[i].splits_initialized) {
+ ++total_delta[make_pair(accumulator, label)];
+ for (const int32 split : results[i].split_adds) {
+ ++split_delta[make_tuple(accumulator, split, label)];
}
- node_index = left_child +
- DecideNode(point, tree(node_index, FEATURE_INDEX),
- thresholds(node_index));
}
}
// candidate splits pcw indices
Tensor* output_candidate_pcw_indices = nullptr;
TensorShape candidate_pcw_shape;
candidate_pcw_shape.AddDim(split_delta.size());


@@ -94,7 +94,7 @@ class SampleInputs : public OpKernel {
"split_sampling_random_seed", &split_sampling_random_seed_));
// Set up the random number generator.
if (split_sampling_random_seed_ == 0) {
- uint64 time_seed = static_cast<uint64>(std::time(NULL));
+ uint64 time_seed = static_cast<uint64>(std::clock());
single_rand_ = std::unique_ptr<random::PhiloxRandom>(
new random::PhiloxRandom(time_seed));
} else {
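Why this is the "critical time-seeding fix": std::time(NULL) has one-second resolution, so SampleInputs ops constructed within the same second (routine when building a forest with many trees) would all receive the same PhiloxRandom seed and draw identical "random" splits. std::clock() advances at a much finer granularity, so back-to-back constructions get distinct seeds.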


@@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
- import tensorflow # pylint: disable=unused-import
+ import tensorflow as tf
from tensorflow.contrib.tensor_forest.python.ops import training_ops
@@ -47,6 +47,29 @@ class CountExtremelyRandomStatsTest(test_util.TensorFlowTestCase):
self.tree_thresholds, self.node_map,
self.split_features, self.split_thresholds, num_classes=4))
self.assertAllEqual(
[[1., 1., 1., 1.], [1., 1., 0., 0.], [0., 0., 1., 1.]],
pcw_node.eval())
self.assertAllEqual([[0, 0, 0]], pcw_splits_indices.eval())
self.assertAllEqual([1.], pcw_splits_delta.eval())
self.assertAllEqual([[0, 1], [0, 0]], pcw_totals_indices.eval())
self.assertAllEqual([1., 1.], pcw_totals_delta.eval())
self.assertAllEqual([1, 1, 2, 2], leaves.eval())
def testThreaded(self):
with self.test_session(
config=tf.ConfigProto(intra_op_parallelism_threads=2)):
(pcw_node, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
pcw_totals_delta,
leaves) = (self.ops.count_extremely_random_stats(self.input_data,
self.input_labels,
self.tree,
self.tree_thresholds,
self.node_map,
self.split_features,
self.split_thresholds,
num_classes=4))
self.assertAllEqual([[1., 1., 1., 1.], [1., 1., 0., 0.],
[0., 0., 1., 1.]],
pcw_node.eval())
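The new testThreaded exercises the sharded path by forcing intra_op_parallelism_threads=2 in the session config; its expected values match the single-threaded test above, since sharding must not change what the op computes.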


@@ -25,12 +25,6 @@ import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
- flags = tf.app.flags
- FLAGS = flags.FLAGS
- flags.DEFINE_string('inference_library_base_dir', '',
- 'Directory to look for inference library file.')
INFERENCE_OPS_FILE = '_inference_ops.so'
_inference_ops = None
@@ -55,12 +49,12 @@ def TreePredictions(op):
# there's not yet any guarantee that the shared object exists.
# In which case, "import tensorflow" will always crash, even for users that
# never use contrib.
- def Load():
+ def Load(library_base_dir=''):
"""Load the inference ops library and return the loaded module."""
with _ops_lock:
global _inference_ops
if not _inference_ops:
- data_files_path = os.path.join(FLAGS.inference_library_base_dir,
+ data_files_path = os.path.join(library_base_dir,
tf.resource_loader.get_data_files_path())
tf.logging.info('data path: %s', data_files_path)
_inference_ops = tf.load_op_library(os.path.join(
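With the flag gone, a caller that keeps the shared object somewhere non-standard now passes the directory explicitly. A minimal sketch of the new entry point (the custom path is hypothetical):

    from tensorflow.contrib.tensor_forest.python.ops import inference_ops

    # Default: look for _inference_ops.so next to the package's data files.
    ops_lib = inference_ops.Load()

    # Hypothetical custom location for the shared object.
    ops_lib = inference_ops.Load(library_base_dir='/opt/tf_forest/lib')

training_ops.Load() gains the identical parameter in the next file.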


@@ -25,11 +25,6 @@ import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
- flags = tf.app.flags
- FLAGS = flags.FLAGS
- flags.DEFINE_string('training_library_base_dir', '',
- 'Directory to look for inference library file.')
TRAINING_OPS_FILE = '_training_ops.so'
@@ -102,12 +97,12 @@ def _UpdateFertileSlotsShape(unused_op):
# there's not yet any guarantee that the shared object exists.
# In which case, "import tensorflow" will always crash, even for users that
# never use contrib.
- def Load():
+ def Load(library_base_dir=''):
"""Load training ops library and return the loaded module."""
with _ops_lock:
global _training_ops
if not _training_ops:
- data_files_path = os.path.join(FLAGS.training_library_base_dir,
+ data_files_path = os.path.join(library_base_dir,
tf.resource_loader.get_data_files_path())
tf.logging.info('data path: %s', data_files_path)
_training_ops = tf.load_op_library(os.path.join(


@@ -25,26 +25,6 @@ from tensorflow.contrib.tensor_forest.python.ops import inference_ops
from tensorflow.contrib.tensor_forest.python.ops import training_ops
- flags = tf.app.flags
- FLAGS = flags.FLAGS
- # Default parameter values. These are all only used if the corresponding
- # parameter is not specified when constructing the ForestHParams.
- flags.DEFINE_integer('num_trees', 100, 'Number of trees in forest')
- flags.DEFINE_integer('max_nodes', 10000, 'Maxmimum number of tree nodes.')
- flags.DEFINE_float(
- 'samples_to_decide', 25.0,
- 'Only decide on a split, or only fully use a leaf, after this many '
- 'training samples have been seen.')
- flags.DEFINE_float('bagging_fraction', 1.0,
- 'Use this fraction of the input, randomly chosen, to train '
- 'each tree in the forest.')
- flags.DEFINE_integer(
- 'num_splits_to_consider', 0,
- 'If non-zero, consider this many candidates for a splitting '
- 'rule at a fertile node.')
# If tree[i][0] equals this value, then i is a leaf node.
LEAF_NODE = -1
@@ -64,7 +44,20 @@ LEAF_NODE = -1
class ForestHParams(object):
"""A base class for holding hyperparameters and calculating good defaults."""
- def __init__(self, **kwargs):
+ def __init__(self, num_trees=100, max_nodes=10000, bagging_fraction=1.0,
+ samples_to_decide=25, max_depth=0, num_splits_to_consider=0,
+ max_fertile_nodes=0, split_after_samples=0,
+ valid_leaf_threshold=0, **kwargs):
+ self.num_trees = num_trees
+ self.max_nodes = max_nodes
+ self.bagging_fraction = bagging_fraction
+ self.samples_to_decide = samples_to_decide
+ self.max_depth = max_depth
+ self.num_splits_to_consider = num_splits_to_consider
+ self.max_fertile_nodes = max_fertile_nodes
+ self.split_after_samples = split_after_samples
+ self.valid_leaf_threshold = valid_leaf_threshold
for name, value in kwargs.items():
setattr(self, name, value)
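Since the defaults now live in the constructor, a forest is configured without touching tf.app.flags. A minimal sketch (the values are hypothetical; num_classes has no default, and num_features feeds the derivations in the next hunks):

    from tensorflow.contrib.tensor_forest.python import tensor_forest

    params = tensor_forest.ForestHParams(
        num_classes=4,    # required: the code below fails fast without it
        num_features=10,  # used to derive num_splits_to_consider
        num_trees=50,
        max_nodes=1000)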
@@ -76,24 +69,21 @@
# Fail fast if num_classes isn't set.
_ = getattr(self, 'num_classes')
- self.bagging_fraction = getattr(self, 'bagging_fraction',
- FLAGS.bagging_fraction)
- self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees)
- self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes)
+ self.training_library_base_dir = getattr(
+ self, 'training_library_base_dir', '')
+ self.inference_library_base_dir = getattr(
+ self, 'inference_library_base_dir', '')
# Allow each tree to be unbalanced by up to a factor of 2.
- self.max_depth = getattr(self, 'max_depth',
- int(2 * math.ceil(math.log(self.max_nodes, 2))))
+ self.max_depth = (self.max_depth or
+ int(2 * math.ceil(math.log(self.max_nodes, 2))))
# The Random Forest literature recommends sqrt(# features) for
# classification problems, and p/3 for regression problems.
# TODO(thomaswc): Consider capping this for large number of features.
- self.num_splits_to_consider = getattr(self, 'num_splits_to_consider',
- FLAGS.num_splits_to_consider)
- if not self.num_splits_to_consider:
- self.num_splits_to_consider = max(10, int(
- math.ceil(math.sqrt(self.num_features))))
+ self.num_splits_to_consider = (
+ self.num_splits_to_consider or
+ max(10, int(math.ceil(math.sqrt(self.num_features)))))
# max_fertile_nodes doesn't affect performance, only training speed.
# We therefore set it primarily based upon space considerations.
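A worked example of the two derivations above, using the constructor default max_nodes=10000 and a hypothetical num_features=784:

    import math

    max_nodes = 10000
    num_features = 784

    # log2(10000) is about 13.29; ceil gives 14, doubled for unbalanced trees.
    max_depth = int(2 * math.ceil(math.log(max_nodes, 2)))  # 28

    # sqrt(784) = 28, with a floor of 10 candidate splits.
    num_splits_to_consider = max(10, int(math.ceil(math.sqrt(num_features))))  # 28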
@@ -103,7 +93,7 @@
num_fertile = int(math.ceil(self.max_nodes / self.num_splits_to_consider))
# But always use at least 1000 accumulate slots.
num_fertile = max(num_fertile, 1000)
- self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile)
+ self.max_fertile_nodes = self.max_fertile_nodes or num_fertile
# But it also never needs to be larger than the number of leaves,
# which is max_nodes / 2.
self.max_fertile_nodes = min(self.max_fertile_nodes,
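Continuing the worked example: ceil(10000 / 28) = 358 fertile slots, raised to the 1000-slot floor; 1000 is already below max_nodes / 2 = 5000, so the cap leaves max_fertile_nodes at 1000.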
@@ -111,14 +101,11 @@
# split_after_samples and valid_leaf_threshold should be about the same.
# Therefore, if either is set, use it to set the other. Otherwise, fall
- # back on FLAGS.samples_to_decide.
- samples_to_decide = (
- getattr(self, 'split_after_samples',
- getattr(self, 'valid_leaf_threshold', FLAGS.samples_to_decide)))
- self.split_after_samples = getattr(self, 'split_after_samples',
- samples_to_decide)
- self.valid_leaf_threshold = getattr(self, 'valid_leaf_threshold',
- samples_to_decide)
+ # back on samples_to_decide.
+ samples_to_decide = self.split_after_samples or self.samples_to_decide
+ self.split_after_samples = self.split_after_samples or samples_to_decide
+ self.valid_leaf_threshold = self.valid_leaf_threshold or samples_to_decide
# We have num_splits_to_consider slots to fill, and we want to spend
# approximately split_after_samples samples initializing them.
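With the constructor defaults, neither split_after_samples nor valid_leaf_threshold is set, so both resolve to samples_to_decide = 25. Note that the rewritten fallback consults only split_after_samples, not valid_leaf_threshold as the old getattr chain did.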
@@ -249,9 +236,12 @@
tf.logging.info(self.params.__dict__)
self.variables = variables or ForestTrainingVariables(
self.params, device_assigner=self.device_assigner)
- self.trees = [RandomTreeGraphs(self.variables[i], self.params,
- training_ops.Load(), inference_ops.Load())
- for i in range(self.params.num_trees)]
+ self.trees = [
+ RandomTreeGraphs(
+ self.variables[i], self.params,
+ training_ops.Load(self.params.training_library_base_dir),
+ inference_ops.Load(self.params.inference_library_base_dir))
+ for i in range(self.params.num_trees)]
def training_graph(self, input_data, input_labels):
"""Constructs a TF graph for training a random forest.