diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h index 5690c3a6014..a22af7ab71e 100644 --- a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h @@ -16,6 +16,7 @@ #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ #include +#include #include #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h" @@ -250,10 +251,37 @@ class WeightedQuantilesSummary { float compression_eps = ApproximationError() + (1.0 / num_boundaries); compressed_summary.Compress(num_boundaries, compression_eps); + // Remove the least important boundaries by the gap removing them would + // create. + std::list boundaries_to_keep; + for (int64 i = 0; i != compressed_summary.entries_.size(); ++i) { + boundaries_to_keep.push_back(i); + } + while (boundaries_to_keep.size() > num_boundaries) { + std::list::iterator min_element = boundaries_to_keep.end(); + auto prev = boundaries_to_keep.begin(); + auto curr = prev; + ++curr; + auto next = curr; + ++next; + WeightType min_weight = TotalWeight(); + for (; next != boundaries_to_keep.end(); ++prev, ++curr, ++next) { + WeightType new_weight = + compressed_summary.entries_[*next].PrevMaxRank() - + compressed_summary.entries_[*prev].NextMinRank(); + if (new_weight < min_weight) { + min_element = curr; + min_weight = new_weight; + } + } + boundaries_to_keep.erase(min_element); + } + // Return boundaries. - output.reserve(compressed_summary.entries_.size()); - for (const auto& entry : compressed_summary.entries_) { - output.push_back(entry.value); + output.reserve(boundaries_to_keep.size()); + for (auto itr = boundaries_to_keep.begin(); itr != boundaries_to_keep.end(); + ++itr) { + output.push_back(compressed_summary.entries_[*itr].value); } return output; } diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py index fb44c33d602..7c3a382c955 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py @@ -82,7 +82,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase): self.eps = 0.01 self.max_elements = 1 << 16 - self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64) + self.num_quantiles = constant_op.constant(4, dtype=dtypes.int64) def testBasicQuantileBucketsSingleResource(self): with self.cached_session() as sess: @@ -183,7 +183,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: accumulator = boosted_trees_ops.QuantileAccumulator( - num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") + num_streams=2, + num_quantiles=self.num_quantiles, + epsilon=self.eps, + name="q0") save = saver.Saver() resources.initialize_resources(resources.shared_resources()).run() @@ -202,7 +205,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase): with self.session(graph=ops.Graph()) as sess: accumulator = boosted_trees_ops.QuantileAccumulator( - num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") + num_streams=2, + num_quantiles=self.num_quantiles, + epsilon=self.eps, + name="q0") save = saver.Saver() save.restore(sess, save_path) buckets = accumulator.get_bucket_boundaries() @@ -215,7 +221,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: accumulator = boosted_trees_ops.QuantileAccumulator( - num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") + num_streams=2, + num_quantiles=self.num_quantiles, + epsilon=self.eps, + name="q0") save = saver.Saver() resources.initialize_resources(resources.shared_resources()).run() @@ -233,7 +242,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase): with self.session(graph=ops.Graph()) as sess: accumulator = boosted_trees_ops.QuantileAccumulator( - num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") + num_streams=2, + num_quantiles=self.num_quantiles, + epsilon=self.eps, + name="q0") save = saver.Saver() save.restore(sess, save_path) buckets = accumulator.get_bucket_boundaries()