Fix default direction to left when almost no sparsity for a sparse inequality split.
PiperOrigin-RevId: 196026149
This commit is contained in:
parent
5d47c53adb
commit
42ee0ef7bc
@ -422,6 +422,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
|
||||
GradientStats(*gradients_t, *hessians_t, bucket_idx);
|
||||
}
|
||||
present_gradient_stats *= normalizer_ratio;
|
||||
GradientStats not_present =
|
||||
root_gradient_stats - present_gradient_stats;
|
||||
// If there was (almost) no sparsity, fix the default direction to LEFT.
|
||||
bool fixed_default_direction = not_present.IsAlmostZero();
|
||||
|
||||
GradientStats left_gradient_stats;
|
||||
for (int64 element_idx = start_index; element_idx < end_index;
|
||||
@ -441,6 +445,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
|
||||
// backward pass gradients.
|
||||
GradientStats right_gradient_stats =
|
||||
present_gradient_stats - left_gradient_stats;
|
||||
|
||||
{
|
||||
NodeStats left_stats_default_left =
|
||||
ComputeNodeStats(root_gradient_stats - right_gradient_stats);
|
||||
@ -457,7 +462,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
|
||||
best_dimension_idx = dimension_id;
|
||||
}
|
||||
}
|
||||
{
|
||||
// Consider calculating the default direction only when there were
|
||||
// enough missing examples.
|
||||
if (!fixed_default_direction) {
|
||||
NodeStats left_stats_default_right =
|
||||
ComputeNodeStats(left_gradient_stats);
|
||||
NodeStats right_stats_default_right =
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
|
||||
from tensorflow.contrib.boosted_trees.proto import learner_pb2
|
||||
from tensorflow.contrib.boosted_trees.proto import split_info_pb2
|
||||
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
|
||||
@ -399,6 +401,65 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
self.assertAllClose(0.6, split_node.split.threshold)
|
||||
|
||||
def testMakeSparseSplitDefaultDirectionIsStable(self):
|
||||
"""Tests default direction is stable when no sparsity."""
|
||||
random.seed(1123)
|
||||
for _ in range(50):
|
||||
with self.test_session() as sess:
|
||||
grad = random.random()
|
||||
hessian = random.random()
|
||||
# The data looks like the following (divide by the num of steps 2).
|
||||
# Gradients | Partition | bucket ID |
|
||||
# (grad, hessian) | 0 | -1 |
|
||||
# And then 100 buckets of
|
||||
# (grad/100, hessian/100), so there is no sparsity.
|
||||
n_buckets = 100
|
||||
|
||||
# 1 for the overall sum, and 100 buckets.
|
||||
partition_ids = array_ops.constant(
|
||||
[0] * (n_buckets + 1), dtype=dtypes.int32)
|
||||
# We have only 1 dimension in our sparse feature column.
|
||||
|
||||
bucket_ids = [-1] + [n for n in range(100)]
|
||||
bucket_ids = array_ops.constant(bucket_ids, dtype=dtypes.int64)
|
||||
dimension_ids = array_ops.constant(
|
||||
[0] * (n_buckets + 1), dtype=dtypes.int64)
|
||||
bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1)
|
||||
|
||||
gradients = [grad] + [grad / n_buckets] * n_buckets
|
||||
gradients = array_ops.constant(gradients)
|
||||
hessians = [hessian] + [hessian / n_buckets] * n_buckets
|
||||
hessians = array_ops.constant(hessians)
|
||||
|
||||
boundaries = [x * 1 for x in range(n_buckets + 1)]
|
||||
bucket_boundaries = array_ops.constant(boundaries, dtype=dtypes.float32)
|
||||
|
||||
partitions, gains, splits = (
|
||||
split_handler_ops.build_sparse_inequality_splits(
|
||||
num_minibatches=2,
|
||||
partition_ids=partition_ids,
|
||||
bucket_ids=bucket_ids,
|
||||
gradients=gradients,
|
||||
hessians=hessians,
|
||||
bucket_boundaries=bucket_boundaries,
|
||||
l1_regularization=0,
|
||||
l2_regularization=2,
|
||||
tree_complexity_regularization=0,
|
||||
min_node_weight=0,
|
||||
feature_column_group_id=0,
|
||||
bias_feature_id=-1,
|
||||
class_id=-1,
|
||||
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
|
||||
partitions, gains, splits = (sess.run([partitions, gains, splits]))
|
||||
self.assertAllEqual([0], partitions)
|
||||
self.assertEqual(1, len(splits))
|
||||
|
||||
split_info = split_info_pb2.SplitInfo()
|
||||
split_info.ParseFromString(splits[0])
|
||||
self.assertTrue(
|
||||
split_info.split_node.HasField(
|
||||
'sparse_float_binary_split_default_left'))
|
||||
|
||||
def testMakeMulticlassSparseSplit(self):
|
||||
"""Tests split handler op."""
|
||||
with self.test_session() as sess:
|
||||
|
Loading…
Reference in New Issue
Block a user