First iteration of oblivious tree split handling for dense features.
PiperOrigin-RevId: 208705535
parent 0c98648e9a
commit af827be63a
@@ -34,7 +34,9 @@
 
 namespace tensorflow {
 
+using boosted_trees::learner::LearnerConfig;
 using boosted_trees::learner::LearnerConfig_MultiClassStrategy;
+using boosted_trees::learner::ObliviousSplitInfo;
 using boosted_trees::learner::SplitInfo;
 using boosted_trees::learner::stochastic::GradientStats;
 using boosted_trees::learner::stochastic::NodeStats;
@@ -158,6 +160,11 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
     const Tensor* hessians_t;
     OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
 
+    const Tensor* weak_learner_type_t;
+    OP_REQUIRES_OK(context,
+                   context->input("weak_learner_type", &weak_learner_type_t));
+    const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
     // Find the number of unique partitions before we allocate the output.
     std::vector<int32> partition_boundaries;
     partition_boundaries.push_back(0);
@@ -188,20 +195,59 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
     tensorflow::TTypes<int32>::Vec output_partition_ids =
         output_partition_ids_t->vec<int32>();
 
-    Tensor* gains_t = nullptr;
-    OP_REQUIRES_OK(
-        context, context->allocate_output("gains", TensorShape({num_elements}),
-                                          &gains_t));
+    // For a normal tree, we output a split per partition. For an oblivious
+    // tree, we output one split for all partitions of the layer.
+    int32 size_output = num_elements;
+    if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
+        num_elements > 0) {
+      size_output = 1;
+    }
+
+    Tensor* gains_t = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                "gains", TensorShape({size_output}), &gains_t));
     tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
 
     Tensor* output_splits_t = nullptr;
-    OP_REQUIRES_OK(context, context->allocate_output(
-                                "split_infos", TensorShape({num_elements}),
-                                &output_splits_t));
+    OP_REQUIRES_OK(context, context->allocate_output("split_infos",
+                                                     TensorShape({size_output}),
+                                                     &output_splits_t));
     tensorflow::TTypes<string>::Vec output_splits =
         output_splits_t->vec<string>();
 
+    if (num_elements == 0) {
+      return;
+    }
     SplitBuilderState state(context);
+    switch (weak_learner_type) {
+      case LearnerConfig::NORMAL_DECISION_TREE: {
+        ComputeNormalDecisionTree(
+            &state, normalizer_ratio, num_elements, partition_boundaries,
+            bucket_boundaries, partition_ids, bucket_ids, gradients_t,
+            hessians_t, &output_partition_ids, &gains, &output_splits);
+        break;
+      }
+      case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+        ComputeObliviousDecisionTree(
+            &state, normalizer_ratio, num_elements, partition_boundaries,
+            bucket_boundaries, partition_ids, bucket_ids, gradients_t,
+            hessians_t, &output_partition_ids, &gains, &output_splits);
+        break;
+      }
+    }
+  }
+
+ private:
+  void ComputeNormalDecisionTree(
+      SplitBuilderState* state, const float normalizer_ratio,
+      const int num_elements, const std::vector<int32>& partition_boundaries,
+      const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
+      const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+      const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
+      const Tensor* gradients_t, const Tensor* hessians_t,
+      tensorflow::TTypes<int32>::Vec* output_partition_ids,
+      tensorflow::TTypes<float>::Vec* gains,
+      tensorflow::TTypes<string>::Vec* output_splits) {
     for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
       float best_gain = std::numeric_limits<float>::lowest();
       int start_index = partition_boundaries[root_idx];
@@ -213,7 +259,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
             GradientStats(*gradients_t, *hessians_t, bucket_idx);
       }
       root_gradient_stats *= normalizer_ratio;
-      NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
+      NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
       int32 best_bucket_idx = 0;
       NodeStats best_right_node_stats(0);
       NodeStats best_left_node_stats(0);
@@ -223,10 +269,10 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
         GradientStats g(*gradients_t, *hessians_t, bucket_idx);
         g *= normalizer_ratio;
         left_gradient_stats += g;
-        NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
+        NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
         GradientStats right_gradient_stats =
             root_gradient_stats - left_gradient_stats;
-        NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
+        NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
         if (left_stats.gain + right_stats.gain > best_gain) {
           best_gain = left_stats.gain + right_stats.gain;
           best_left_node_stats = left_stats;
@@ -237,21 +283,125 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
       SplitInfo split_info;
       auto* dense_split =
           split_info.mutable_split_node()->mutable_dense_float_binary_split();
-      dense_split->set_feature_column(state.feature_column_group_id());
+      dense_split->set_feature_column(state->feature_column_group_id());
       dense_split->set_threshold(
           bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
 
       auto* left_child = split_info.mutable_left_child();
       auto* right_child = split_info.mutable_right_child();
 
-      state.FillLeaf(best_left_node_stats, left_child);
-      state.FillLeaf(best_right_node_stats, right_child);
-      split_info.SerializeToString(&output_splits(root_idx));
-      gains(root_idx) =
-          best_gain - root_stats.gain - state.tree_complexity_regularization();
-      output_partition_ids(root_idx) = partition_ids(start_index);
+      state->FillLeaf(best_left_node_stats, left_child);
+      state->FillLeaf(best_right_node_stats, right_child);
+      split_info.SerializeToString(&(*output_splits)(root_idx));
+      (*gains)(root_idx) =
+          best_gain - root_stats.gain - state->tree_complexity_regularization();
+      (*output_partition_ids)(root_idx) = partition_ids(start_index);
     }
   }
+  void ComputeObliviousDecisionTree(
+      SplitBuilderState* state, const float normalizer_ratio,
+      const int num_elements, const std::vector<int32>& partition_boundaries,
+      const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
+      const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+      const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
+      const Tensor* gradients_t, const Tensor* hessians_t,
+      tensorflow::TTypes<int32>::Vec* output_partition_ids,
+      tensorflow::TTypes<float>::Vec* gains,
+      tensorflow::TTypes<string>::Vec* output_splits) {
+    // Holds the root stats for each node to be split.
+    std::vector<GradientStats> current_layer_stats;
+    current_layer_stats.reserve(num_elements);
+    for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+      const int start_index = partition_boundaries[root_idx];
+      const int end_index = partition_boundaries[root_idx + 1];
+      GradientStats root_gradient_stats;
+      for (int64 bucket_idx = start_index; bucket_idx < end_index;
+           ++bucket_idx) {
+        root_gradient_stats +=
+            GradientStats(*gradients_t, *hessians_t, bucket_idx);
+      }
+      root_gradient_stats *= normalizer_ratio;
+      current_layer_stats.push_back(root_gradient_stats);
+    }
+
+    float best_gain = std::numeric_limits<float>::lowest();
+    int64 best_bucket_idx = 0;
+    std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
+    std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
+    std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
+    std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
+    int64 current_bucket_id = 0;
+    int64 last_bucket_id = -1;
+    // Index offsets into each of the partitions, used to access the gradients
+    // of a partition for the current bucket under consideration.
+    std::vector<int> current_layer_offsets(num_elements, 0);
+    std::vector<GradientStats> left_gradient_stats(num_elements);
+    // The idea is to try every bucket id in increasing order. In each
+    // iteration we calculate the gain of the layer using the current bucket
+    // id as the split value, and we also obtain the next bucket id to try.
+    while (current_bucket_id > last_bucket_id) {
+      last_bucket_id = current_bucket_id;
+      int64 next_bucket_id = -1;
+      for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+        int idx =
+            current_layer_offsets[root_idx] + partition_boundaries[root_idx];
+        const int end_index = partition_boundaries[root_idx + 1];
+        if (idx < end_index && bucket_ids(idx, 0) == current_bucket_id) {
+          GradientStats g(*gradients_t, *hessians_t, idx);
+          g *= normalizer_ratio;
+          left_gradient_stats[root_idx] += g;
+          current_layer_offsets[root_idx]++;
+          idx++;
+        }
+        if (idx < end_index &&
+            (bucket_ids(idx, 0) < next_bucket_id || next_bucket_id == -1)) {
+          next_bucket_id = bucket_ids(idx, 0);
+        }
+      }
+      float gain_of_split = 0.0;
+      for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+        GradientStats right_gradient_stats =
+            current_layer_stats[root_idx] - left_gradient_stats[root_idx];
+        NodeStats left_stat =
+            state->ComputeNodeStats(left_gradient_stats[root_idx]);
+        NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
+        gain_of_split += left_stat.gain + right_stat.gain;
+        current_left_node_stats[root_idx] = left_stat;
+        current_right_node_stats[root_idx] = right_stat;
+      }
+      if (gain_of_split > best_gain) {
+        best_gain = gain_of_split;
+        best_left_node_stats = current_left_node_stats;
+        best_right_node_stats = current_right_node_stats;
+      }
+      current_bucket_id = next_bucket_id;
+    }
+
+    for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+      best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
+    }
+    best_gain -= num_elements * state->tree_complexity_regularization();
+
+    ObliviousSplitInfo oblivious_split_info;
+    auto* oblivious_dense_split = oblivious_split_info.mutable_split_node()
+                                      ->mutable_dense_float_binary_split();
+    oblivious_dense_split->set_feature_column(state->feature_column_group_id());
+    oblivious_dense_split->set_threshold(
+        bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
+    (*gains)(0) = best_gain;
+
+    for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+      auto* left_children = oblivious_split_info.add_children_leaves();
+      auto* right_children = oblivious_split_info.add_children_leaves();
+
+      state->FillLeaf(best_left_node_stats[root_idx], left_children);
+      state->FillLeaf(best_right_node_stats[root_idx], right_children);
+
+      const int start_index = partition_boundaries[root_idx];
+      (*output_partition_ids)(root_idx) = partition_ids(start_index);
+    }
+    oblivious_split_info.SerializeToString(&(*output_splits)(0));
+  }
 };
 REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU),
                         BuildDenseInequalitySplitsOp);
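To make the layer-wise search in ComputeObliviousDecisionTree above easier to follow, here is a minimal Python sketch of the same bucket sweep. It assumes scalar gradients and hessians and uses a simplified L2-only gain g**2 / (h + l2) in place of SplitBuilderState::ComputeNodeStats; all names below are illustrative, not part of the TensorFlow API.

def node_gain(g, h, l2=1.0):
    # Simplified L2-only leaf gain; the kernel's NodeStats also handles L1
    # regularization and multiclass strategies.
    return g * g / (h + l2)

def oblivious_layer_split(partitions, l2=1.0):
    """partitions: per-partition lists of (bucket_id, grad, hess), sorted by bucket_id."""
    # Root (bias) statistics per partition, as in current_layer_stats.
    roots = [(sum(g for _, g, _ in p), sum(h for _, _, h in p)) for p in partitions]
    offsets = [0] * len(partitions)        # current_layer_offsets
    left = [(0.0, 0.0)] * len(partitions)  # left_gradient_stats
    best_gain, best_bucket = float("-inf"), None
    current_bucket, last_bucket = 0, -1
    # Try every bucket id in increasing order, like the while loop in the
    # kernel: evaluate the layer gain at the current bucket, then move on to
    # the smallest bucket id that any partition still has ahead of it.
    while current_bucket > last_bucket:
        last_bucket = current_bucket
        next_bucket = -1
        for i, p in enumerate(partitions):
            if offsets[i] < len(p) and p[offsets[i]][0] == current_bucket:
                _, g, h = p[offsets[i]]
                left[i] = (left[i][0] + g, left[i][1] + h)
                offsets[i] += 1
            if offsets[i] < len(p) and (next_bucket == -1 or
                                        p[offsets[i]][0] < next_bucket):
                next_bucket = p[offsets[i]][0]
        # One shared threshold for the whole layer: the candidate gain is the
        # sum over all partitions of left-child gain plus right-child gain.
        gain = sum(node_gain(lg, lh, l2) + node_gain(rg - lg, rh - lh, l2)
                   for (lg, lh), (rg, rh) in zip(left, roots))
        if gain > best_gain:
            best_gain, best_bucket = gain, current_bucket
        current_bucket = next_bucket
    # Subtract each partition's bias gain, mirroring the loop after the sweep.
    best_gain -= sum(node_gain(g, h, l2) for g, h in roots)
    return best_bucket, best_gain

The point to notice is that, unlike ComputeNormalDecisionTree, the candidate gain is summed across every partition of the layer before being compared against the best so far, so the whole layer ends up sharing a single feature and threshold.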
@@ -64,6 +64,7 @@ from __future__ import print_function
 import re
 
 from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.proto import learner_pb2
 from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops
 from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
 from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
|
||||
multiclass_strategy,
|
||||
init_stamp_token=0,
|
||||
loss_uses_sum_reduction=False,
|
||||
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
|
||||
name=None):
|
||||
"""Initialize the internal state for this split handler.
|
||||
|
||||
@@ -192,6 +194,7 @@ class DenseSplitHandler(InequalitySplitHandler):
         stamped objects.
       loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
         SUM or MEAN reduction was used for the loss.
+      weak_learner_type: Specifies the type of weak learner to use.
       name: An optional handler name.
     """
     super(DenseSplitHandler, self).__init__(
@@ -209,6 +212,7 @@ class DenseSplitHandler(InequalitySplitHandler):
         multiclass_strategy=multiclass_strategy,
         loss_uses_sum_reduction=loss_uses_sum_reduction)
     self._dense_float_column = dense_float_column
+    self._weak_learner_type = weak_learner_type
     # Register dense_make_stats_update function as an Op to the graph.
     g = ops.get_default_graph()
     dense_make_stats_update.add_to_graph(g)
@@ -269,16 +273,17 @@ class DenseSplitHandler(InequalitySplitHandler):
              next_stamp_token, self._multiclass_strategy, class_id,
              self._feature_column_group_id, self._l1_regularization,
              self._l2_regularization, self._tree_complexity_regularization,
-             self._min_node_weight, self._loss_uses_sum_reduction))
-
+             self._min_node_weight, self._loss_uses_sum_reduction,
+             self._weak_learner_type))
     return are_splits_ready, partition_ids, gains, split_infos
 
 
-def _make_dense_split(
-    quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
-    next_stamp_token, multiclass_strategy, class_id, feature_column_id,
-    l1_regularization, l2_regularization, tree_complexity_regularization,
-    min_node_weight, is_multi_dimentional, loss_uses_sum_reduction):
+def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
+                      stamp_token, next_stamp_token, multiclass_strategy,
+                      class_id, feature_column_id, l1_regularization,
+                      l2_regularization, tree_complexity_regularization,
+                      min_node_weight, is_multi_dimentional,
+                      loss_uses_sum_reduction, weak_learner_type):
   """Function that builds splits for a dense feature column."""
   # Get the bucket boundaries
   are_splits_ready, buckets = (
@@ -327,7 +332,8 @@ def _make_dense_split(
           l2_regularization=l2_regularization,
           tree_complexity_regularization=tree_complexity_regularization,
          min_node_weight=min_node_weight,
-          multiclass_strategy=multiclass_strategy))
+          multiclass_strategy=multiclass_strategy,
+          weak_learner_type=weak_learner_type))
   return are_splits_ready, partition_ids, gains, split_infos
 
 
@@ -507,7 +513,40 @@ def _make_sparse_split(
   return are_splits_ready, partition_ids, gains, split_infos
 
 
-def _specialize_make_split(func, is_multi_dimentional):
+def _specialize_make_split_dense(func, is_multi_dimentional):
+  """Builds a specialized version of the function."""
+
+  @function.Defun(
+      dtypes.resource,
+      dtypes.resource,
+      dtypes.int64,
+      dtypes.int64,
+      dtypes.int32,
+      dtypes.int32,
+      dtypes.int32,
+      dtypes.float32,
+      dtypes.float32,
+      dtypes.float32,
+      dtypes.float32,
+      dtypes.bool,
+      dtypes.int32,
+      noinline=True)
+  def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+        next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+        l1_regularization, l2_regularization, tree_complexity_regularization,
+        min_node_weight, loss_uses_sum_reduction, weak_learner_type):
+    """Function that builds splits for a dense feature column."""
+    return func(quantile_accumulator_handle, stats_accumulator_handle,
+                stamp_token, next_stamp_token, multiclass_strategy, class_id,
+                feature_column_id, l1_regularization, l2_regularization,
+                tree_complexity_regularization, min_node_weight,
+                is_multi_dimentional, loss_uses_sum_reduction,
+                weak_learner_type)
+
+  return f
+
+
+def _specialize_make_split_sparse(func, is_multi_dimentional):
   """Builds a specialized version of the function."""
 
   @function.Defun(
@@ -537,15 +576,17 @@ def _specialize_make_split(func, is_multi_dimentional):
 
   return f
 
-make_dense_split_scalar = _specialize_make_split(_make_dense_split,
-                                                 is_multi_dimentional=False)
-make_dense_split_tensor = _specialize_make_split(_make_dense_split,
-                                                 is_multi_dimentional=True)
-
-make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
-                                                  is_multi_dimentional=False)
-make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
-                                                  is_multi_dimentional=True)
+make_dense_split_scalar = _specialize_make_split_dense(
+    _make_dense_split, is_multi_dimentional=False)
+
+make_dense_split_tensor = _specialize_make_split_dense(
+    _make_dense_split, is_multi_dimentional=True)
+
+make_sparse_split_scalar = _specialize_make_split_sparse(
+    _make_sparse_split, is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_make_split_sparse(
+    _make_sparse_split, is_multi_dimentional=True)
 
 
 @function.Defun(
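A note on the specialization pattern above: is_multi_dimentional is baked into each generated function at graph-construction time by closing over it, which is why separate scalar and tensor variants are created instead of passing the flag as a runtime tensor. A minimal sketch of that pattern, with illustrative names:

def specialize(func, is_multi_dimentional):
    # The closed-over flag becomes a fixed constant of this variant, the same
    # way the Defun-decorated f above captures it from its enclosing scope.
    def f(feature):
        return func(feature, is_multi_dimentional)
    return f

def make_split(feature, is_multi_dimentional):
    return ("tensor" if is_multi_dimentional else "scalar", feature)

make_split_scalar = specialize(make_split, False)
make_split_tensor = specialize(make_split, True)

assert make_split_scalar("f0") == ("scalar", "f0")
assert make_split_tensor("f0") == ("tensor", "f0")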
@@ -182,6 +182,133 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
 
     self.assertAllClose(0.52, split_node.threshold, 0.00001)
 
+  def testObliviousFeatureSplitGeneration(self):
+    with self.test_session() as sess:
+      # The data looks like the following:
+      # Example | Gradients    | Partition | Dense Quantile |
+      # i0      | (0.2, 0.12)  | 0         | 2              |
+      # i1      | (-0.5, 0.07) | 0         | 2              |
+      # i2      | (1.2, 0.2)   | 0         | 0              |
+      # i3      | (4.0, 0.13)  | 1         | 1              |
+      dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52])
+      gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+      hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+      partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
+      class_id = -1
+
+      gradient_shape = tensor_shape.scalar()
+      hessian_shape = tensor_shape.scalar()
+      split_handler = ordinal_split_handler.DenseSplitHandler(
+          l1_regularization=0.1,
+          l2_regularization=1.,
+          tree_complexity_regularization=0.,
+          min_node_weight=0.,
+          epsilon=0.001,
+          num_quantiles=10,
+          feature_column_group_id=0,
+          dense_float_column=dense_column,
+          init_stamp_token=0,
+          gradient_shape=gradient_shape,
+          hessian_shape=hessian_shape,
+          multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+          weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+      resources.initialize_resources(resources.shared_resources()).run()
+
+      empty_gradients, empty_hessians = get_empty_tensors(
+          gradient_shape, hessian_shape)
+      example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+      update_1 = split_handler.update_stats_sync(
+          0,
+          partition_ids,
+          gradients,
+          hessians,
+          empty_gradients,
+          empty_hessians,
+          example_weights,
+          is_active=array_ops.constant([True, True]))
+      with ops.control_dependencies([update_1]):
+        are_splits_ready = split_handler.make_splits(
+            np.int64(0), np.int64(1), class_id)[0]
+
+      with ops.control_dependencies([are_splits_ready]):
+        update_2 = split_handler.update_stats_sync(
+            1,
+            partition_ids,
+            gradients,
+            hessians,
+            empty_gradients,
+            empty_hessians,
+            example_weights,
+            is_active=array_ops.constant([True, True]))
+      with ops.control_dependencies([update_2]):
+        are_splits_ready2, partitions, gains, splits = (
+            split_handler.make_splits(np.int64(1), np.int64(2), class_id))
+      are_splits_ready, are_splits_ready2, partitions, gains, splits = (
+          sess.run([
+              are_splits_ready, are_splits_ready2, partitions, gains, splits
+          ]))
+
+      # During the first iteration, inequality split handlers are not going to
+      # have any splits. Make sure that we return not_ready in that case.
+      self.assertFalse(are_splits_ready)
+      self.assertTrue(are_splits_ready2)
+
+      self.assertAllEqual([0, 1], partitions)
+
+      oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
+      oblivious_split_info.ParseFromString(splits[0])
+      split_node = oblivious_split_info.split_node.dense_float_binary_split
+
+      self.assertAllClose(0.3, split_node.threshold, 0.00001)
+      self.assertEqual(0, split_node.feature_column)
+
+      # Check the split on partition 0.
+      # -(1.2 - 0.1) / (0.2 + 1)
+      expected_left_weight_0 = -0.9166666666666666
+
+      # expected_left_weight_0 * -(1.2 - 0.1)
+      expected_left_gain_0 = 1.008333333333333
+
+      # -(-0.5 + 0.2 + 0.1) / (0.19 + 1)
+      expected_right_weight_0 = 0.1680672
+
+      # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1)
+      expected_right_gain_0 = 0.033613445378151252
+
+      # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+      expected_bias_gain_0 = 0.46043165467625896
+
+      left_child = oblivious_split_info.children_leaves[0].vector
+      right_child = oblivious_split_info.children_leaves[1].vector
+
+      self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
+
+      self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001)
+
+      # Check the split on partition 1.
+      expected_left_weight_1 = 0
+      expected_left_gain_1 = 0
+      # -(4 - 0.1) / (0.13 + 1)
+      expected_right_weight_1 = -3.4513274336283186
+      # expected_right_weight_1 * -(4 - 0.1)
+      expected_right_gain_1 = 13.460176991150442
+      # (-4 + 0.1) ** 2 / (0.13 + 1)
+      expected_bias_gain_1 = 13.460176991150442
+
+      left_child = oblivious_split_info.children_leaves[2].vector
+      right_child = oblivious_split_info.children_leaves[3].vector
+
+      self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
+
+      self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)
+
+      # The layer gain is the sum of the gains of each partition.
+      layer_gain = (
+          expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + (
+              expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1)
+      self.assertAllClose(layer_gain, gains[0], 0.00001)
+
   def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
     with self.test_session() as sess:
       # The data looks like the following:
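The expected values in testObliviousFeatureSplitGeneration can be re-derived by hand. A small standalone check, assuming the usual L1-shrinkage form of the leaf gain used throughout these tests, g' = sign(g) * max(|g| - l1, 0) and gain = g'**2 / (h + l2):

l1, l2 = 0.1, 1.0

def gain(g, h):
    shrunk = max(abs(g) - l1, 0.0)
    return shrunk * shrunk / (h + l2)

# Partition 0: i2 (0.3) goes left of the 0.3 threshold, i0 and i1 go right.
gain_0 = gain(1.2, 0.2) + gain(0.2 - 0.5, 0.12 + 0.07) - gain(0.9, 0.39)
# Partition 1: i3 (0.52) falls right of the shared threshold; the left is empty.
gain_1 = gain(0.0, 0.0) + gain(4.0, 0.13) - gain(4.0, 0.13)
print(gain_0 + gain_1)  # ~0.581514, the layer_gain the test asserts for gains[0]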
@@ -36,6 +36,7 @@ REGISTER_OP("BuildDenseInequalitySplits")
     .Input("tree_complexity_regularization: float")
     .Input("min_node_weight: float")
     .Input("multiclass_strategy: int32")
+    .Input("weak_learner_type: int32")
     .Output("output_partition_ids: int32")
     .Output("gains: float32")
     .Output("split_infos: string")
@@ -84,6 +85,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child.
   be considered.
 multiclass_strategy: A scalar, specifying the multiclass handling strategy.
   See LearnerConfig.MultiClassStrategy for valid values.
+weak_learner_type: A scalar, specifying the weak learner type to use.
+  See LearnerConfig.WeakLearnerType for valid values.
 output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
   for.
 gains: A rank 1 tensor, for the computed gain for the created splits.
@@ -108,6 +108,11 @@ message LearnerConfig {
     DIAGONAL_HESSIAN = 3;
   }
 
+  enum WeakLearnerType {
+    NORMAL_DECISION_TREE = 0;
+    OBLIVIOUS_DECISION_TREE = 1;
+  }
+
   // Number of classes.
   uint32 num_classes = 1;
 
@@ -141,4 +146,7 @@ message LearnerConfig {
   // If you want to average the ensembles (for regularization), provide the
   // config below.
   AveragingConfig averaging_config = 11;
+
+  // By default we use NORMAL_DECISION_TREE as weak learner.
+  WeakLearnerType weak_learner_type = 12;
 }
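For reference, a short hedged sketch of how the new field is set from Python through the generated bindings; nested proto enums are exposed as attributes of the containing message class, as the tests in this change already rely on:

from tensorflow.contrib.boosted_trees.proto import learner_pb2

config = learner_pb2.LearnerConfig()
# NORMAL_DECISION_TREE is the zero value and hence the default, matching the
# comment above; opting in to the oblivious weak learner is explicit.
config.weak_learner_type = learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE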
@@ -17,3 +17,10 @@ message SplitInfo {
   // Right Leaf node.
   tensorflow.boosted_trees.trees.Leaf right_child = 3;
 }
+
+message ObliviousSplitInfo {
+  // The split node with the feature_column and threshold defined.
+  tensorflow.boosted_trees.trees.TreeNode split_node = 1;
+  // The new leaves of the tree.
+  repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2;
+}
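The kernel appends one left/right leaf pair per partition, in partition order, so children_leaves always holds 2*k entries for a layer covering k partitions (the test above reads indices 0..3 for its two partitions). A small illustrative reader; the helper name is hypothetical:

from tensorflow.contrib.boosted_trees.proto import split_info_pb2

def oblivious_leaf_pairs(serialized_split):
    """Returns [(left_leaf, right_leaf), ...], one pair per partition."""
    info = split_info_pb2.ObliviousSplitInfo()
    info.ParseFromString(serialized_split)
    leaves = info.children_leaves
    return [(leaves[i], leaves[i + 1]) for i in range(0, len(leaves), 2)]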
@@ -59,7 +59,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
             min_node_weight=0,
             class_id=-1,
             feature_column_group_id=0,
-            multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+            multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+            weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
     partitions, gains, splits = sess.run([partitions, gains, splits])
     self.assertAllEqual([0, 1], partitions)
 
@@ -132,7 +133,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
             min_node_weight=0,
             class_id=-1,
             feature_column_group_id=0,
-            multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
+            multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
+            weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
     partitions, gains, splits = sess.run([partitions, gains, splits])
     self.assertAllEqual([0, 1], partitions)
 
@@ -171,7 +173,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
             min_node_weight=0,
             class_id=-1,
             feature_column_group_id=0,
-            multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+            multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+            weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
     partitions, gains, splits = sess.run([partitions, gains, splits])
     # .assertEmpty doesn't exist on ubuntu-contrib
     self.assertEqual(0, len(partitions))