Improve the bucket generation in Boosted Trees to avoid returning more buckets than requested.

PiperOrigin-RevId: 312371738
Change-Id: I7f241c839f52d679ad4ceb82c161018e9b944fa3
This commit is contained in:
A. Unique TensorFlower 2020-05-19 15:58:25 -07:00 committed by TensorFlower Gardener
parent d894109fe1
commit 91da977a03
2 changed files with 48 additions and 8 deletions

View File

@ -16,6 +16,7 @@
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#include <cstring> #include <cstring>
#include <list>
#include <vector> #include <vector>
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h" #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
@ -250,10 +251,37 @@ class WeightedQuantilesSummary {
float compression_eps = ApproximationError() + (1.0 / num_boundaries); float compression_eps = ApproximationError() + (1.0 / num_boundaries);
compressed_summary.Compress(num_boundaries, compression_eps); compressed_summary.Compress(num_boundaries, compression_eps);
// Remove the least important boundaries, ranked by the size of the gap that
// removing each one would create.
std::list<int64> boundaries_to_keep;
for (int64 i = 0; i != compressed_summary.entries_.size(); ++i) {
boundaries_to_keep.push_back(i);
}
while (boundaries_to_keep.size() > num_boundaries) {
std::list<int64>::iterator min_element = boundaries_to_keep.end();
auto prev = boundaries_to_keep.begin();
auto curr = prev;
++curr;
auto next = curr;
++next;
WeightType min_weight = TotalWeight();
for (; next != boundaries_to_keep.end(); ++prev, ++curr, ++next) {
WeightType new_weight =
compressed_summary.entries_[*next].PrevMaxRank() -
compressed_summary.entries_[*prev].NextMinRank();
if (new_weight < min_weight) {
min_element = curr;
min_weight = new_weight;
}
}
boundaries_to_keep.erase(min_element);
}
// Return boundaries. // Return boundaries.
output.reserve(compressed_summary.entries_.size()); output.reserve(boundaries_to_keep.size());
for (const auto& entry : compressed_summary.entries_) { for (auto itr = boundaries_to_keep.begin(); itr != boundaries_to_keep.end();
output.push_back(entry.value); ++itr) {
output.push_back(compressed_summary.entries_[*itr].value);
} }
return output; return output;
} }

View File

@ -82,7 +82,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
self.eps = 0.01 self.eps = 0.01
self.max_elements = 1 << 16 self.max_elements = 1 << 16
self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64) self.num_quantiles = constant_op.constant(4, dtype=dtypes.int64)
def testBasicQuantileBucketsSingleResource(self): def testBasicQuantileBucketsSingleResource(self):
with self.cached_session() as sess: with self.cached_session() as sess:
@ -183,7 +183,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess: with self.cached_session() as sess:
accumulator = boosted_trees_ops.QuantileAccumulator( accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") num_streams=2,
num_quantiles=self.num_quantiles,
epsilon=self.eps,
name="q0")
save = saver.Saver() save = saver.Saver()
resources.initialize_resources(resources.shared_resources()).run() resources.initialize_resources(resources.shared_resources()).run()
@ -202,7 +205,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:
accumulator = boosted_trees_ops.QuantileAccumulator( accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") num_streams=2,
num_quantiles=self.num_quantiles,
epsilon=self.eps,
name="q0")
save = saver.Saver() save = saver.Saver()
save.restore(sess, save_path) save.restore(sess, save_path)
buckets = accumulator.get_bucket_boundaries() buckets = accumulator.get_bucket_boundaries()
@ -215,7 +221,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess: with self.cached_session() as sess:
accumulator = boosted_trees_ops.QuantileAccumulator( accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") num_streams=2,
num_quantiles=self.num_quantiles,
epsilon=self.eps,
name="q0")
save = saver.Saver() save = saver.Saver()
resources.initialize_resources(resources.shared_resources()).run() resources.initialize_resources(resources.shared_resources()).run()
@ -233,7 +242,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:
accumulator = boosted_trees_ops.QuantileAccumulator( accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0") num_streams=2,
num_quantiles=self.num_quantiles,
epsilon=self.eps,
name="q0")
save = saver.Saver() save = saver.Saver()
save.restore(sess, save_path) save.restore(sess, save_path)
buckets = accumulator.get_bucket_boundaries() buckets = accumulator.get_bucket_boundaries()