Fixes and enhancements for contrib/tensorforest:
- remove flags (pushed to internal client) - thread-parallel execution for CountExtremelyRandomStats op. - critical time-seeding fix. Change: 120129936
This commit is contained in:
parent
449ecb561f
commit
c2d9cb1d08
tensorflow/contrib/tensor_forest
@ -18,6 +18,7 @@
|
||||
// only op that involves tree traversal, and is constructed so that it can
|
||||
// be run in parallel on separate batches of data.
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
|
||||
|
||||
@ -25,10 +26,12 @@
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using std::get;
|
||||
using std::make_pair;
|
||||
using std::make_tuple;
|
||||
using std::pair;
|
||||
using std::tuple;
|
||||
@ -42,6 +45,71 @@ using tensorforest::DecideNode;
|
||||
using tensorforest::Initialize;
|
||||
using tensorforest::IsAllInitialized;
|
||||
|
||||
// A data structure to store the results of parallel tree traversal.
|
||||
struct InputDataResult {
|
||||
// A list of each node that was visited.
|
||||
std::vector<int32> node_indices;
|
||||
// The accumulator of the leaf that a data point ended up at, or -1 if none.
|
||||
int32 leaf_accumulator;
|
||||
// The left-branch taken candidate splits.
|
||||
std::vector<int32> split_adds;
|
||||
// If the candidate splits for the leaf that a data point arrived at
|
||||
// were initialized or not, which determines if we add this to total
|
||||
// pcw counts or not.
|
||||
bool splits_initialized;
|
||||
};
|
||||
|
||||
void Evaluate(const Tensor& input_data, const Tensor& input_labels,
|
||||
const Tensor& tree_tensor, const Tensor& tree_thresholds,
|
||||
const Tensor& node_to_accumulator,
|
||||
const Tensor& candidate_split_features,
|
||||
const Tensor& candidate_split_thresholds,
|
||||
InputDataResult* results, int64 start, int64 end) {
|
||||
const auto tree = tree_tensor.tensor<int32, 2>();
|
||||
const auto thresholds = tree_thresholds.unaligned_flat<float>();
|
||||
const auto node_map = node_to_accumulator.unaligned_flat<int32>();
|
||||
const auto split_features = candidate_split_features.tensor<int32, 2>();
|
||||
const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
|
||||
|
||||
const int32 num_splits = candidate_split_features.shape().dim_size(1);
|
||||
|
||||
for (int i = start; i < end; ++i) {
|
||||
const Tensor point = input_data.Slice(i, i + 1);
|
||||
int node_index = 0;
|
||||
results[i].splits_initialized = false;
|
||||
while (true) {
|
||||
results[i].node_indices.push_back(node_index);
|
||||
int32 left_child = tree(node_index, CHILDREN_INDEX);
|
||||
if (left_child == LEAF_NODE) {
|
||||
const int32 accumulator = node_map(node_index);
|
||||
results[i].leaf_accumulator = accumulator;
|
||||
// If the leaf is not fertile or is not yet initialized, we don't
|
||||
// count it in the candidate/total split per-class-weights because
|
||||
// it won't have any candidate splits yet.
|
||||
if (accumulator >= 0 &&
|
||||
IsAllInitialized(candidate_split_features.Slice(
|
||||
accumulator, accumulator + 1))) {
|
||||
results[i].splits_initialized = true;
|
||||
for (int split = 0; split < num_splits; split++) {
|
||||
if (!DecideNode(point, split_features(accumulator, split),
|
||||
split_thresholds(accumulator, split))) {
|
||||
results[i].split_adds.push_back(split);
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
} else if (left_child == FREE_NODE) {
|
||||
LOG(ERROR) << "Reached a free node, not good.";
|
||||
results[i].node_indices.push_back(FREE_NODE);
|
||||
break;
|
||||
}
|
||||
node_index =
|
||||
left_child + DecideNode(point, tree(node_index, FEATURE_INDEX),
|
||||
thresholds(node_index));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_OP("CountExtremelyRandomStats")
|
||||
.Attr("num_classes: int32")
|
||||
.Input("input_data: float")
|
||||
@ -176,7 +244,31 @@ class CountExtremelyRandomStats : public OpKernel {
|
||||
"candidate_split_features and candidate_split_thresholds should be "
|
||||
"the same shape."));
|
||||
|
||||
const int32 num_splits = candidate_split_features.shape().dim_size(1);
|
||||
// Evaluate input data in parallel.
|
||||
const int64 num_data = input_data.shape().dim_size(0);
|
||||
std::unique_ptr<InputDataResult[]> results(new InputDataResult[num_data]);
|
||||
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
|
||||
int num_threads = worker_threads->num_threads;
|
||||
if (num_threads <= 1) {
|
||||
Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
|
||||
node_to_accumulator, candidate_split_features,
|
||||
candidate_split_thresholds, results.get(), 0, num_data);
|
||||
} else {
|
||||
auto work = [&input_data, &input_labels, &tree_tensor, &tree_thresholds,
|
||||
&node_to_accumulator, &candidate_split_features,
|
||||
&candidate_split_thresholds, &num_data,
|
||||
&results](int64 start, int64 end) {
|
||||
CHECK(start <= end);
|
||||
CHECK(end <= num_data);
|
||||
Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
|
||||
node_to_accumulator, candidate_split_features,
|
||||
candidate_split_thresholds, results.get(), start, end);
|
||||
};
|
||||
Shard(num_threads, worker_threads->workers, num_data, 100, work);
|
||||
}
|
||||
|
||||
// Set output tensors.
|
||||
const auto labels = input_labels.unaligned_flat<int32>();
|
||||
|
||||
// node pcw delta
|
||||
Tensor* output_node_pcw_delta = nullptr;
|
||||
@ -196,58 +288,28 @@ class CountExtremelyRandomStats : public OpKernel {
|
||||
&output_leaves));
|
||||
auto out_leaves = output_leaves->unaligned_flat<int32>();
|
||||
|
||||
const auto tree = tree_tensor.tensor<int32, 2>();
|
||||
const auto thresholds = tree_thresholds.unaligned_flat<float>();
|
||||
const auto labels = input_labels.unaligned_flat<int32>();
|
||||
const auto node_map = node_to_accumulator.unaligned_flat<int32>();
|
||||
const auto split_features = candidate_split_features.tensor<int32, 2>();
|
||||
const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
|
||||
|
||||
const int32 num_data = input_data.shape().dim_size(0);
|
||||
|
||||
// <accumulator, class> -> count delta
|
||||
std::unordered_map<pair<int32, int32>, int32, PairIntHash> total_delta;
|
||||
// <accumulator, split, class> -> count delta
|
||||
std::unordered_map<tuple<int32, int32, int32>,
|
||||
int32, TupleIntHash> split_delta;
|
||||
for (int i = 0; i < num_data; i++) {
|
||||
const Tensor point = input_data.Slice(i, i+1);
|
||||
int node_index = 0;
|
||||
while (true) {
|
||||
const int32 label = labels(i);
|
||||
++out_node(node_index, label);
|
||||
int32 left_child = tree(node_index, CHILDREN_INDEX);
|
||||
if (left_child == LEAF_NODE) {
|
||||
out_leaves(i) = node_index;
|
||||
const int32 accumulator = node_map(node_index);
|
||||
// If the leaf is not fertile or is not yet initialized, we don't
|
||||
// count it in the candidate/total split per-class-weights because
|
||||
// it won't have any candidate splits yet.
|
||||
if (accumulator >= 0 &&
|
||||
IsAllInitialized(
|
||||
candidate_split_features.Slice(accumulator,
|
||||
accumulator + 1))) {
|
||||
++total_delta[std::make_pair(accumulator, label)];
|
||||
for (int split = 0; split < num_splits; split++) {
|
||||
if (!DecideNode(point, split_features(accumulator, split),
|
||||
split_thresholds(accumulator, split))) {
|
||||
++split_delta[make_tuple(accumulator, split, label)];
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
} else if (left_child == FREE_NODE) {
|
||||
LOG(ERROR) << "Reached a free node, not good.";
|
||||
out_leaves(i) = FREE_NODE;
|
||||
break;
|
||||
|
||||
for (int32 i = 0; i < num_data; ++i) {
|
||||
const int32 label = labels(i);
|
||||
const int32 accumulator = results[i].leaf_accumulator;
|
||||
for (const int32 node : results[i].node_indices) {
|
||||
++out_node(node, label);
|
||||
}
|
||||
out_leaves(i) = results[i].node_indices.back();
|
||||
if (accumulator >= 0 && results[i].splits_initialized) {
|
||||
++total_delta[make_pair(accumulator, label)];
|
||||
for (const int32 split : results[i].split_adds) {
|
||||
++split_delta[make_tuple(accumulator, split, label)];
|
||||
}
|
||||
node_index = left_child +
|
||||
DecideNode(point, tree(node_index, FEATURE_INDEX),
|
||||
thresholds(node_index));
|
||||
}
|
||||
}
|
||||
|
||||
// candidate splits pcw indices
|
||||
// candidate splits pcw indices
|
||||
Tensor* output_candidate_pcw_indices = nullptr;
|
||||
TensorShape candidate_pcw_shape;
|
||||
candidate_pcw_shape.AddDim(split_delta.size());
|
||||
|
@ -94,7 +94,7 @@ class SampleInputs : public OpKernel {
|
||||
"split_sampling_random_seed", &split_sampling_random_seed_));
|
||||
// Set up the random number generator.
|
||||
if (split_sampling_random_seed_ == 0) {
|
||||
uint64 time_seed = static_cast<uint64>(std::time(NULL));
|
||||
uint64 time_seed = static_cast<uint64>(std::clock());
|
||||
single_rand_ = std::unique_ptr<random::PhiloxRandom>(
|
||||
new random::PhiloxRandom(time_seed));
|
||||
} else {
|
||||
|
@ -17,7 +17,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow # pylint: disable=unused-import
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.tensor_forest.python.ops import training_ops
|
||||
|
||||
@ -47,6 +47,29 @@ class CountExtremelyRandomStatsTest(test_util.TensorFlowTestCase):
|
||||
self.tree_thresholds, self.node_map,
|
||||
self.split_features, self.split_thresholds, num_classes=4))
|
||||
|
||||
self.assertAllEqual(
|
||||
[[1., 1., 1., 1.], [1., 1., 0., 0.], [0., 0., 1., 1.]],
|
||||
pcw_node.eval())
|
||||
self.assertAllEqual([[0, 0, 0]], pcw_splits_indices.eval())
|
||||
self.assertAllEqual([1.], pcw_splits_delta.eval())
|
||||
self.assertAllEqual([[0, 1], [0, 0]], pcw_totals_indices.eval())
|
||||
self.assertAllEqual([1., 1.], pcw_totals_delta.eval())
|
||||
self.assertAllEqual([1, 1, 2, 2], leaves.eval())
|
||||
|
||||
def testThreaded(self):
|
||||
with self.test_session(
|
||||
config=tf.ConfigProto(intra_op_parallelism_threads=2)):
|
||||
(pcw_node, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
|
||||
pcw_totals_delta,
|
||||
leaves) = (self.ops.count_extremely_random_stats(self.input_data,
|
||||
self.input_labels,
|
||||
self.tree,
|
||||
self.tree_thresholds,
|
||||
self.node_map,
|
||||
self.split_features,
|
||||
self.split_thresholds,
|
||||
num_classes=4))
|
||||
|
||||
self.assertAllEqual([[1., 1., 1., 1.], [1., 1., 0., 0.],
|
||||
[0., 0., 1., 1.]],
|
||||
pcw_node.eval())
|
||||
|
@ -25,12 +25,6 @@ import tensorflow as tf
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
||||
flags = tf.app.flags
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('inference_library_base_dir', '',
|
||||
'Directory to look for inference library file.')
|
||||
|
||||
INFERENCE_OPS_FILE = '_inference_ops.so'
|
||||
|
||||
_inference_ops = None
|
||||
@ -55,12 +49,12 @@ def TreePredictions(op):
|
||||
# there's not yet any guarantee that the shared object exists.
|
||||
# In which case, "import tensorflow" will always crash, even for users that
|
||||
# never use contrib.
|
||||
def Load():
|
||||
def Load(library_base_dir=''):
|
||||
"""Load the inference ops library and return the loaded module."""
|
||||
with _ops_lock:
|
||||
global _inference_ops
|
||||
if not _inference_ops:
|
||||
data_files_path = os.path.join(FLAGS.inference_library_base_dir,
|
||||
data_files_path = os.path.join(library_base_dir,
|
||||
tf.resource_loader.get_data_files_path())
|
||||
tf.logging.info('data path: %s', data_files_path)
|
||||
_inference_ops = tf.load_op_library(os.path.join(
|
||||
|
@ -25,11 +25,6 @@ import tensorflow as tf
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
|
||||
flags = tf.app.flags
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('training_library_base_dir', '',
|
||||
'Directory to look for inference library file.')
|
||||
|
||||
TRAINING_OPS_FILE = '_training_ops.so'
|
||||
|
||||
@ -102,12 +97,12 @@ def _UpdateFertileSlotsShape(unused_op):
|
||||
# there's not yet any guarantee that the shared object exists.
|
||||
# In which case, "import tensorflow" will always crash, even for users that
|
||||
# never use contrib.
|
||||
def Load():
|
||||
def Load(library_base_dir=''):
|
||||
"""Load training ops library and return the loaded module."""
|
||||
with _ops_lock:
|
||||
global _training_ops
|
||||
if not _training_ops:
|
||||
data_files_path = os.path.join(FLAGS.training_library_base_dir,
|
||||
data_files_path = os.path.join(library_base_dir,
|
||||
tf.resource_loader.get_data_files_path())
|
||||
tf.logging.info('data path: %s', data_files_path)
|
||||
_training_ops = tf.load_op_library(os.path.join(
|
||||
|
@ -25,26 +25,6 @@ from tensorflow.contrib.tensor_forest.python.ops import inference_ops
|
||||
from tensorflow.contrib.tensor_forest.python.ops import training_ops
|
||||
|
||||
|
||||
flags = tf.app.flags
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
# Default parameter values. These are all only used if the corresponding
|
||||
# parameter is not specified when constructing the ForestHParams.
|
||||
flags.DEFINE_integer('num_trees', 100, 'Number of trees in forest')
|
||||
flags.DEFINE_integer('max_nodes', 10000, 'Maxmimum number of tree nodes.')
|
||||
flags.DEFINE_float(
|
||||
'samples_to_decide', 25.0,
|
||||
'Only decide on a split, or only fully use a leaf, after this many '
|
||||
'training samples have been seen.')
|
||||
flags.DEFINE_float('bagging_fraction', 1.0,
|
||||
'Use this fraction of the input, randomly chosen, to train '
|
||||
'each tree in the forest.')
|
||||
flags.DEFINE_integer(
|
||||
'num_splits_to_consider', 0,
|
||||
'If non-zero, consider this many candidates for a splitting '
|
||||
'rule at a fertile node.')
|
||||
|
||||
# If tree[i][0] equals this value, then i is a leaf node.
|
||||
LEAF_NODE = -1
|
||||
|
||||
@ -64,7 +44,20 @@ LEAF_NODE = -1
|
||||
class ForestHParams(object):
|
||||
"""A base class for holding hyperparameters and calculating good defaults."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, num_trees=100, max_nodes=10000, bagging_fraction=1.0,
|
||||
samples_to_decide=25, max_depth=0, num_splits_to_consider=0,
|
||||
max_fertile_nodes=0, split_after_samples=0,
|
||||
valid_leaf_threshold=0, **kwargs):
|
||||
self.num_trees = num_trees
|
||||
self.max_nodes = max_nodes
|
||||
self.bagging_fraction = bagging_fraction
|
||||
self.samples_to_decide = samples_to_decide
|
||||
self.max_depth = max_depth
|
||||
self.num_splits_to_consider = num_splits_to_consider
|
||||
self.max_fertile_nodes = max_fertile_nodes
|
||||
self.split_after_samples = split_after_samples
|
||||
self.valid_leaf_threshold = valid_leaf_threshold
|
||||
|
||||
for name, value in kwargs.items():
|
||||
setattr(self, name, value)
|
||||
|
||||
@ -76,24 +69,21 @@ class ForestHParams(object):
|
||||
# Fail fast if num_classes isn't set.
|
||||
_ = getattr(self, 'num_classes')
|
||||
|
||||
self.bagging_fraction = getattr(self, 'bagging_fraction',
|
||||
FLAGS.bagging_fraction)
|
||||
|
||||
self.num_trees = getattr(self, 'num_trees', FLAGS.num_trees)
|
||||
self.max_nodes = getattr(self, 'max_nodes', FLAGS.max_nodes)
|
||||
self.training_library_base_dir = getattr(
|
||||
self, 'training_library_base_dir', '')
|
||||
self.inference_library_base_dir = getattr(
|
||||
self, 'inference_library_base_dir', '')
|
||||
|
||||
# Allow each tree to be unbalanced by up to a factor of 2.
|
||||
self.max_depth = getattr(self, 'max_depth',
|
||||
int(2 * math.ceil(math.log(self.max_nodes, 2))))
|
||||
self.max_depth = (self.max_depth or
|
||||
int(2 * math.ceil(math.log(self.max_nodes, 2))))
|
||||
|
||||
# The Random Forest literature recommends sqrt(# features) for
|
||||
# classification problems, and p/3 for regression problems.
|
||||
# TODO(thomaswc): Consider capping this for large number of features.
|
||||
self.num_splits_to_consider = getattr(self, 'num_splits_to_consider',
|
||||
FLAGS.num_splits_to_consider)
|
||||
if not self.num_splits_to_consider:
|
||||
self.num_splits_to_consider = max(10, int(
|
||||
math.ceil(math.sqrt(self.num_features))))
|
||||
self.num_splits_to_consider = (
|
||||
self.num_splits_to_consider or
|
||||
max(10, int(math.ceil(math.sqrt(self.num_features)))))
|
||||
|
||||
# max_fertile_nodes doesn't affect performance, only training speed.
|
||||
# We therefore set it primarily based upon space considerations.
|
||||
@ -103,7 +93,7 @@ class ForestHParams(object):
|
||||
num_fertile = int(math.ceil(self.max_nodes / self.num_splits_to_consider))
|
||||
# But always use at least 1000 accumulate slots.
|
||||
num_fertile = max(num_fertile, 1000)
|
||||
self.max_fertile_nodes = getattr(self, 'max_fertile_nodes', num_fertile)
|
||||
self.max_fertile_nodes = self.max_fertile_nodes or num_fertile
|
||||
# But it also never needs to be larger than the number of leaves,
|
||||
# which is max_nodes / 2.
|
||||
self.max_fertile_nodes = min(self.max_fertile_nodes,
|
||||
@ -111,14 +101,11 @@ class ForestHParams(object):
|
||||
|
||||
# split_after_samples and valid_leaf_threshold should be about the same.
|
||||
# Therefore, if either is set, use it to set the other. Otherwise, fall
|
||||
# back on FLAGS.samples_to_decide.
|
||||
samples_to_decide = (
|
||||
getattr(self, 'split_after_samples',
|
||||
getattr(self, 'valid_leaf_threshold', FLAGS.samples_to_decide)))
|
||||
self.split_after_samples = getattr(self, 'split_after_samples',
|
||||
samples_to_decide)
|
||||
self.valid_leaf_threshold = getattr(self, 'valid_leaf_threshold',
|
||||
samples_to_decide)
|
||||
# back on samples_to_decide.
|
||||
samples_to_decide = self.split_after_samples or self.samples_to_decide
|
||||
|
||||
self.split_after_samples = self.split_after_samples or samples_to_decide
|
||||
self.valid_leaf_threshold = self.valid_leaf_threshold or samples_to_decide
|
||||
|
||||
# We have num_splits_to_consider slots to fill, and we want to spend
|
||||
# approximately split_after_samples samples initializing them.
|
||||
@ -249,9 +236,12 @@ class RandomForestGraphs(object):
|
||||
tf.logging.info(self.params.__dict__)
|
||||
self.variables = variables or ForestTrainingVariables(
|
||||
self.params, device_assigner=self.device_assigner)
|
||||
self.trees = [RandomTreeGraphs(self.variables[i], self.params,
|
||||
training_ops.Load(), inference_ops.Load())
|
||||
for i in range(self.params.num_trees)]
|
||||
self.trees = [
|
||||
RandomTreeGraphs(
|
||||
self.variables[i], self.params,
|
||||
training_ops.Load(self.params.training_library_base_dir),
|
||||
inference_ops.Load(self.params.inference_library_base_dir))
|
||||
for i in range(self.params.num_trees)]
|
||||
|
||||
def training_graph(self, input_data, input_labels):
|
||||
"""Constructs a TF graph for training a random forest.
|
||||
|
Loading…
Reference in New Issue
Block a user