First iteration of oblivious tree split handling for dense features.

PiperOrigin-RevId: 208705535
This commit is contained in:
A. Unique TensorFlower 2018-08-14 13:46:42 -07:00 committed by TensorFlower Gardener
parent 0c98648e9a
commit af827be63a
7 changed files with 376 additions and 37 deletions
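For context: an oblivious decision tree uses one shared split (feature column and threshold) for every node of a layer, whereas a normal tree picks an independent split per partition. A minimal Python sketch of the layer-wide gain search this change implements (the function name and the gain_fn callback are illustrative, not the kernel's API):

def best_oblivious_split(partitions, candidate_bucket_ids, gain_fn):
  # partitions: per-node statistics for the current layer.
  # candidate_bucket_ids: sorted, deduplicated bucket ids across partitions.
  best_gain = float("-inf")
  best_bucket_id = None
  for bucket_id in candidate_bucket_ids:
    # Gain of splitting EVERY node of the layer at this bucket id.
    layer_gain = sum(gain_fn(node, bucket_id) for node in partitions)
    if layer_gain > best_gain:
      best_gain, best_bucket_id = layer_gain, bucket_id
  return best_bucket_id, best_gain

ComputeObliviousDecisionTree below performs this search in a single merge-like sweep over the sorted per-partition bucket ids and emits one gain and one split_info for the whole layer, while output_partition_ids still lists every partition.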

View File

@@ -34,7 +34,9 @@
namespace tensorflow {
using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearnerConfig_MultiClassStrategy;
using boosted_trees::learner::ObliviousSplitInfo;
using boosted_trees::learner::SplitInfo;
using boosted_trees::learner::stochastic::GradientStats;
using boosted_trees::learner::stochastic::NodeStats;
@@ -158,6 +160,11 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
const Tensor* weak_learner_type_t;
OP_REQUIRES_OK(context,
context->input("weak_learner_type", &weak_learner_type_t));
const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
partition_boundaries.push_back(0);
@@ -188,20 +195,59 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
tensorflow::TTypes<int32>::Vec output_partition_ids =
output_partition_ids_t->vec<int32>();
-Tensor* gains_t = nullptr;
-OP_REQUIRES_OK(
-    context, context->allocate_output("gains", TensorShape({num_elements}),
-                                      &gains_t));
+// For a normal tree, we output a split per partition. For an oblivious
+// tree, we output one split for all partitions of the layer.
+int32 size_output = num_elements;
+if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
+    num_elements > 0) {
+  size_output = 1;
+}
+Tensor* gains_t = nullptr;
+OP_REQUIRES_OK(context, context->allocate_output(
+                            "gains", TensorShape({size_output}), &gains_t));
tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
Tensor* output_splits_t = nullptr;
-OP_REQUIRES_OK(context, context->allocate_output(
-                            "split_infos", TensorShape({num_elements}),
-                            &output_splits_t));
+OP_REQUIRES_OK(context, context->allocate_output("split_infos",
+                                                 TensorShape({size_output}),
+                                                 &output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
if (num_elements == 0) {
return;
}
SplitBuilderState state(context);
switch (weak_learner_type) {
case LearnerConfig::NORMAL_DECISION_TREE: {
ComputeNormalDecisionTree(
&state, normalizer_ratio, num_elements, partition_boundaries,
bucket_boundaries, partition_ids, bucket_ids, gradients_t,
hessians_t, &output_partition_ids, &gains, &output_splits);
break;
}
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
ComputeObliviousDecisionTree(
&state, normalizer_ratio, num_elements, partition_boundaries,
bucket_boundaries, partition_ids, bucket_ids, gradients_t,
hessians_t, &output_partition_ids, &gains, &output_splits);
break;
}
}
}
private:
void ComputeNormalDecisionTree(
SplitBuilderState* state, const float normalizer_ratio,
const int num_elements, const std::vector<int32>& partition_boundaries,
const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
const tensorflow::TTypes<int32>::ConstVec& partition_ids,
const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
const Tensor* gradients_t, const Tensor* hessians_t,
tensorflow::TTypes<int32>::Vec* output_partition_ids,
tensorflow::TTypes<float>::Vec* gains,
tensorflow::TTypes<string>::Vec* output_splits) {
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[root_idx];
@@ -213,7 +259,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
root_gradient_stats *= normalizer_ratio;
-NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
+NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
int32 best_bucket_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@@ -223,10 +269,10 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
GradientStats g(*gradients_t, *hessians_t, bucket_idx);
g *= normalizer_ratio;
left_gradient_stats += g;
-NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
+NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
-NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
+NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@@ -237,21 +283,125 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
SplitInfo split_info;
auto* dense_split =
split_info.mutable_split_node()->mutable_dense_float_binary_split();
-dense_split->set_feature_column(state.feature_column_group_id());
+dense_split->set_feature_column(state->feature_column_group_id());
dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
-state.FillLeaf(best_left_node_stats, left_child);
-state.FillLeaf(best_right_node_stats, right_child);
-split_info.SerializeToString(&output_splits(root_idx));
-gains(root_idx) =
-    best_gain - root_stats.gain - state.tree_complexity_regularization();
-output_partition_ids(root_idx) = partition_ids(start_index);
+state->FillLeaf(best_left_node_stats, left_child);
+state->FillLeaf(best_right_node_stats, right_child);
+split_info.SerializeToString(&(*output_splits)(root_idx));
+(*gains)(root_idx) =
+    best_gain - root_stats.gain - state->tree_complexity_regularization();
+(*output_partition_ids)(root_idx) = partition_ids(start_index);
}
}
void ComputeObliviousDecisionTree(
SplitBuilderState* state, const float normalizer_ratio,
const int num_elements, const std::vector<int32>& partition_boundaries,
const tensorflow::TTypes<float>::ConstVec& bucket_boundaries,
const tensorflow::TTypes<int32>::ConstVec& partition_ids,
const tensorflow::TTypes<int64>::ConstMatrix& bucket_ids,
const Tensor* gradients_t, const Tensor* hessians_t,
tensorflow::TTypes<int32>::Vec* output_partition_ids,
tensorflow::TTypes<float>::Vec* gains,
tensorflow::TTypes<string>::Vec* output_splits) {
// Holds the root stats for each node to be split.
std::vector<GradientStats> current_layer_stats;
current_layer_stats.reserve(num_elements);
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
const int start_index = partition_boundaries[root_idx];
const int end_index = partition_boundaries[root_idx + 1];
GradientStats root_gradient_stats;
for (int64 bucket_idx = start_index; bucket_idx < end_index;
++bucket_idx) {
root_gradient_stats +=
GradientStats(*gradients_t, *hessians_t, bucket_idx);
}
root_gradient_stats *= normalizer_ratio;
current_layer_stats.push_back(root_gradient_stats);
}
float best_gain = std::numeric_limits<float>::lowest();
int64 best_bucket_id = 0;
std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
int64 current_bucket_id = 0;
int64 last_bucket_id = -1;
// Per-partition index offsets, used to access the gradient entries of each
// partition for the bucket id currently under consideration.
std::vector<int> current_layer_offsets(num_elements, 0);
std::vector<GradientStats> left_gradient_stats(num_elements);
// The idea is to try every bucket id in increasing order. In each iteration
// we calculate the gain of the layer using the current bucket id as the
// split value, and we also obtain the next bucket id to try.
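// For example, with sorted bucket ids {0, 2} in partition 0 and {1} in
// partition 1 (the layout of the oblivious test in this change), the sweep
// visits bucket ids 0, 1, 2: every partition whose next entry matches the
// current bucket id folds that entry into its left stats, and next_bucket_id
// becomes the smallest id not yet consumed by any partition. Once all
// partitions are exhausted, next_bucket_id stays -1 and the loop terminates.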
while (current_bucket_id > last_bucket_id) {
last_bucket_id = current_bucket_id;
int64 next_bucket_id = -1;
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
int idx =
current_layer_offsets[root_idx] + partition_boundaries[root_idx];
const int end_index = partition_boundaries[root_idx + 1];
if (idx < end_index && bucket_ids(idx, 0) == current_bucket_id) {
GradientStats g(*gradients_t, *hessians_t, idx);
g *= normalizer_ratio;
left_gradient_stats[root_idx] += g;
current_layer_offsets[root_idx]++;
idx++;
}
if (idx < end_index &&
(bucket_ids(idx, 0) < next_bucket_id || next_bucket_id == -1)) {
next_bucket_id = bucket_ids(idx, 0);
}
}
float gain_of_split = 0.0;
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
GradientStats right_gradient_stats =
current_layer_stats[root_idx] - left_gradient_stats[root_idx];
NodeStats left_stat =
state->ComputeNodeStats(left_gradient_stats[root_idx]);
NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
gain_of_split += left_stat.gain + right_stat.gain;
current_left_node_stats[root_idx] = left_stat;
current_right_node_stats[root_idx] = right_stat;
}
if (gain_of_split > best_gain) {
best_gain = gain_of_split;
// Remember the bucket id that produced the best layer-wide gain, so the
// threshold below is taken from the winning split rather than bucket 0.
best_bucket_id = last_bucket_id;
best_left_node_stats = current_left_node_stats;
best_right_node_stats = current_right_node_stats;
}
current_bucket_id = next_bucket_id;
}
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
}
best_gain -= num_elements * state->tree_complexity_regularization();
ObliviousSplitInfo oblivious_split_info;
auto* oblivious_dense_split = oblivious_split_info.mutable_split_node()
->mutable_dense_float_binary_split();
oblivious_dense_split->set_feature_column(state->feature_column_group_id());
oblivious_dense_split->set_threshold(bucket_boundaries(best_bucket_id));
(*gains)(0) = best_gain;
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
auto* left_children = oblivious_split_info.add_children_leaves();
auto* right_children = oblivious_split_info.add_children_leaves();
state->FillLeaf(best_left_node_stats[root_idx], left_children);
state->FillLeaf(best_right_node_stats[root_idx], right_children);
const int start_index = partition_boundaries[root_idx];
(*output_partition_ids)(root_idx) = partition_ids(start_index);
}
oblivious_split_info.SerializeToString(&(*output_splits)(0));
}
};
REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU),
BuildDenseInequalitySplitsOp);

View File

@@ -64,6 +64,7 @@ from __future__ import print_function
import re
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops
from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
@@ -171,6 +172,7 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy,
init_stamp_token=0,
loss_uses_sum_reduction=False,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
name=None):
"""Initialize the internal state for this split handler.
@@ -192,6 +194,7 @@ class DenseSplitHandler(InequalitySplitHandler):
stamped objects.
loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
SUM or MEAN reduction was used for the loss.
weak_learner_type: Specifies the type of weak learner to use.
name: An optional handler name.
"""
super(DenseSplitHandler, self).__init__(
@@ -209,6 +212,7 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy=multiclass_strategy,
loss_uses_sum_reduction=loss_uses_sum_reduction)
self._dense_float_column = dense_float_column
self._weak_learner_type = weak_learner_type
# Register dense_make_stats_update function as an Op to the graph.
g = ops.get_default_graph()
dense_make_stats_update.add_to_graph(g)
@@ -269,16 +273,17 @@ class DenseSplitHandler(InequalitySplitHandler):
next_stamp_token, self._multiclass_strategy, class_id,
self._feature_column_group_id, self._l1_regularization,
self._l2_regularization, self._tree_complexity_regularization,
-self._min_node_weight, self._loss_uses_sum_reduction))
+self._min_node_weight, self._loss_uses_sum_reduction,
+self._weak_learner_type))
return are_splits_ready, partition_ids, gains, split_infos
-def _make_dense_split(
-    quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
-    next_stamp_token, multiclass_strategy, class_id, feature_column_id,
-    l1_regularization, l2_regularization, tree_complexity_regularization,
-    min_node_weight, is_multi_dimentional, loss_uses_sum_reduction):
+def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
+                      stamp_token, next_stamp_token, multiclass_strategy,
+                      class_id, feature_column_id, l1_regularization,
+                      l2_regularization, tree_complexity_regularization,
+                      min_node_weight, is_multi_dimentional,
+                      loss_uses_sum_reduction, weak_learner_type):
"""Function that builds splits for a dense feature column."""
# Get the bucket boundaries
are_splits_ready, buckets = (
@@ -327,7 +332,8 @@ def _make_dense_split(
l2_regularization=l2_regularization,
tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
-multiclass_strategy=multiclass_strategy))
+multiclass_strategy=multiclass_strategy,
+weak_learner_type=weak_learner_type))
return are_splits_ready, partition_ids, gains, split_infos
@@ -507,7 +513,40 @@ def _make_sparse_split(
return are_splits_ready, partition_ids, gains, split_infos
-def _specialize_make_split(func, is_multi_dimentional):
+def _specialize_make_split_dense(func, is_multi_dimentional):
"""Builds a specialized version of the function."""
@function.Defun(
dtypes.resource,
dtypes.resource,
dtypes.int64,
dtypes.int64,
dtypes.int32,
dtypes.int32,
dtypes.int32,
dtypes.float32,
dtypes.float32,
dtypes.float32,
dtypes.float32,
dtypes.bool,
dtypes.int32,
noinline=True)
def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
next_stamp_token, multiclass_strategy, class_id, feature_column_id,
l1_regularization, l2_regularization, tree_complexity_regularization,
min_node_weight, loss_uses_sum_reduction, weak_learner_type):
"""Function that builds splits for a sparse feature column."""
return func(quantile_accumulator_handle, stats_accumulator_handle,
stamp_token, next_stamp_token, multiclass_strategy, class_id,
feature_column_id, l1_regularization, l2_regularization,
tree_complexity_regularization, min_node_weight,
is_multi_dimentional, loss_uses_sum_reduction,
weak_learner_type)
return f
def _specialize_make_split_sparse(func, is_multi_dimentional):
"""Builds a specialized version of the function."""
@function.Defun(
@@ -537,15 +576,17 @@ def _specialize_make_split(func, is_multi_dimentional):
return f
-make_dense_split_scalar = _specialize_make_split(_make_dense_split,
-                                                 is_multi_dimentional=False)
-make_dense_split_tensor = _specialize_make_split(_make_dense_split,
-                                                 is_multi_dimentional=True)
-make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
-                                                  is_multi_dimentional=False)
-make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
-                                                  is_multi_dimentional=True)
+make_dense_split_scalar = _specialize_make_split_dense(
+    _make_dense_split, is_multi_dimentional=False)
+make_dense_split_tensor = _specialize_make_split_dense(
+    _make_dense_split, is_multi_dimentional=True)
+make_sparse_split_scalar = _specialize_make_split_sparse(
+    _make_sparse_split, is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_make_split_sparse(
+    _make_sparse_split, is_multi_dimentional=True)
@function.Defun(

View File

@@ -182,6 +182,133 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testObliviousFeatureSplitGeneration(self):
with self.test_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 2 |
# i1 | (-0.5, 0.07) | 0 | 2 |
# i2 | (1.2, 0.2) | 0 | 0 |
# i3 | (4.0, 0.13) | 1 | 1 |
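# With this data the quantile boundaries are 0.3, 0.52 and 0.62, so bucket
# ids 0, 1 and 2 correspond to thresholds 0.3, 0.52 and 0.62 (hence the 0.3
# threshold asserted below).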
dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52])
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
class_id = -1
gradient_shape = tensor_shape.scalar()
hessian_shape = tensor_shape.scalar()
split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1,
l2_regularization=1.,
tree_complexity_regularization=0.,
min_node_weight=0.,
epsilon=0.001,
num_quantiles=10,
feature_column_group_id=0,
dense_float_column=dense_column,
init_stamp_token=0,
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
resources.initialize_resources(resources.shared_resources()).run()
empty_gradients, empty_hessians = get_empty_tensors(
gradient_shape, hessian_shape)
example_weights = array_ops.ones([4, 1], dtypes.float32)
update_1 = split_handler.update_stats_sync(
0,
partition_ids,
gradients,
hessians,
empty_gradients,
empty_hessians,
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_1]):
are_splits_ready = split_handler.make_splits(
np.int64(0), np.int64(1), class_id)[0]
with ops.control_dependencies([are_splits_ready]):
update_2 = split_handler.update_stats_sync(
1,
partition_ids,
gradients,
hessians,
empty_gradients,
empty_hessians,
example_weights,
is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
split_handler.make_splits(np.int64(1), np.int64(2), class_id))
are_splits_ready, are_splits_ready2, partitions, gains, splits = (
sess.run([
are_splits_ready, are_splits_ready2, partitions, gains, splits
]))
# During the first iteration, inequality split handlers are not going to
# have any splits. Make sure that we return not_ready in that case.
self.assertFalse(are_splits_ready)
self.assertTrue(are_splits_ready2)
self.assertAllEqual([0, 1], partitions)
oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
oblivious_split_info.ParseFromString(splits[0])
split_node = oblivious_split_info.split_node.dense_float_binary_split
self.assertAllClose(0.3, split_node.threshold, 0.00001)
self.assertEqual(0, split_node.feature_column)
# Check the split on partition 0.
# -(1.2 - 0.1) / (0.2 + 1)
expected_left_weight_0 = -0.9166666666666666
# expected_left_weight_0 * -(1.2 - 0.1)
expected_left_gain_0 = 1.008333333333333
# -(-0.5 + 0.2 + 0.1) / (0.19 + 1)
expected_right_weight_0 = 0.1680672
# expected_right_weight_0 * -(-0.5 + 0.2 + 0.1)
expected_right_gain_0 = 0.033613445378151252
# (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
expected_bias_gain_0 = 0.46043165467625896
left_child = oblivious_split_info.children_leaves[0].vector
right_child = oblivious_split_info.children_leaves[1].vector
self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001)
# Check the split on partition 1.
expected_left_weight_1 = 0
expected_left_gain_1 = 0
# -(4 - 0.1) / (0.13 + 1)
expected_right_weight_1 = -3.4513274336283186
# expected_right_weight_1 * -(4 - 0.1)
expected_right_gain_1 = 13.460176991150442
# (-4 + 0.1) ** 2 / (0.13 + 1)
expected_bias_gain_1 = 13.460176991150442
left_child = oblivious_split_info.children_leaves[2].vector
right_child = oblivious_split_info.children_leaves[3].vector
self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)
# The layer gain is the sum of the gains of each partition.
layer_gain = (
expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + (
expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1)
self.assertAllClose(layer_gain, gains[0], 0.00001)
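The expected values above follow the L1-shrunk leaf weight and gain formulas used by the split handler; a small self-contained Python check (the helper names are illustrative, not part of the test):

def _shrunk(g, l1=0.1):
  # L1 regularization shrinks the gradient sum toward zero.
  if g > l1:
    return g - l1
  if g < -l1:
    return g + l1
  return 0.0

def _leaf_weight(g, h, l1=0.1, l2=1.0):
  return -_shrunk(g, l1) / (h + l2)

def _leaf_gain(g, h, l1=0.1, l2=1.0):
  return _shrunk(g, l1)**2 / (h + l2)

# Partition 0: the left node holds i2, the right node holds i0 and i1.
assert abs(_leaf_weight(1.2, 0.2) - (-0.9166666666666666)) < 1e-6
assert abs(_leaf_weight(-0.3, 0.19) - 0.1680672) < 1e-6
# Partition 1: the right node holds i3.
assert abs(_leaf_gain(4.0, 0.13) - 13.460176991150442) < 1e-6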
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
with self.test_session() as sess:
# The data looks like the following:

View File

@@ -36,6 +36,7 @@ REGISTER_OP("BuildDenseInequalitySplits")
.Input("tree_complexity_regularization: float")
.Input("min_node_weight: float")
.Input("multiclass_strategy: int32")
.Input("weak_learner_type: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -84,6 +85,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child.
be considered.
multiclass_strategy: A scalar, specifying the multiclass handling strategy.
See LearnerConfig.MultiClassStrategy for valid values.
weak_learner_type: A scalar, specifying the weak learner type to use.
See LearnerConfig.WeakLearnerType for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.

View File

@@ -108,6 +108,11 @@ message LearnerConfig {
DIAGONAL_HESSIAN = 3;
}
enum WeakLearnerType {
NORMAL_DECISION_TREE = 0;
OBLIVIOUS_DECISION_TREE = 1;
}
// Number of classes.
uint32 num_classes = 1;
@@ -141,4 +146,7 @@ message LearnerConfig {
// If you want to average the ensembles (for regularization), provide the
// config below.
AveragingConfig averaging_config = 11;
// By default we use NORMAL_DECISION_TREE as the weak learner.
WeakLearnerType weak_learner_type = 12;
}
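A usage sketch for the new field (assuming the contrib proto import path used elsewhere in this change):

from tensorflow.contrib.boosted_trees.proto import learner_pb2

config = learner_pb2.LearnerConfig()
config.num_classes = 2
# Opt in to the oblivious weak learner; the default stays NORMAL_DECISION_TREE.
config.weak_learner_type = learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE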

View File

@@ -17,3 +17,10 @@ message SplitInfo {
// Right Leaf node.
tensorflow.boosted_trees.trees.Leaf right_child = 3;
}
message ObliviousSplitInfo {
// The split node with the feature_column and threshold defined.
tensorflow.boosted_trees.trees.TreeNode split_node = 1;
// The new leaves of the tree.
repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2;
}
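The kernel appends children_leaves left-then-right for each partition, in partition order, so a consumer can index them as sketched below (the helper is illustrative, not part of this change):

from tensorflow.contrib.boosted_trees.proto import split_info_pb2

def children_for_partition(serialized_split_info, i):
  # Partition i owns the leaves at positions 2*i (left) and 2*i + 1 (right).
  info = split_info_pb2.ObliviousSplitInfo()
  info.ParseFromString(serialized_split_info)
  return info.children_leaves[2 * i], info.children_leaves[2 * i + 1]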

View File

@@ -59,7 +59,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
-multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -132,7 +133,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
-multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
+multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
+weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -171,7 +173,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
-multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
# .assertEmpty doesn't exist on ubuntu-contrib
self.assertEqual(0, len(partitions))