Improve the bucket generation in Boosted Trees to avoid returning more than requested buckets.

PiperOrigin-RevId: 312371738
Change-Id: I7f241c839f52d679ad4ceb82c161018e9b944fa3
This commit is contained in:
A. Unique TensorFlower 2020-05-19 15:58:25 -07:00 committed by TensorFlower Gardener
parent d894109fe1
commit 91da977a03
2 changed files with 48 additions and 8 deletions

View File

@ -16,6 +16,7 @@
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#include <cstring>
#include <list>
#include <vector>
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
@ -250,10 +251,37 @@ class WeightedQuantilesSummary {
float compression_eps = ApproximationError() + (1.0 / num_boundaries);
compressed_summary.Compress(num_boundaries, compression_eps);
// Remove the least important boundaries by the gap removing them would
// create.
std::list<int64> boundaries_to_keep;
for (int64 i = 0; i != compressed_summary.entries_.size(); ++i) {
boundaries_to_keep.push_back(i);
}
while (boundaries_to_keep.size() > num_boundaries) {
std::list<int64>::iterator min_element = boundaries_to_keep.end();
auto prev = boundaries_to_keep.begin();
auto curr = prev;
++curr;
auto next = curr;
++next;
WeightType min_weight = TotalWeight();
for (; next != boundaries_to_keep.end(); ++prev, ++curr, ++next) {
WeightType new_weight =
compressed_summary.entries_[*next].PrevMaxRank() -
compressed_summary.entries_[*prev].NextMinRank();
if (new_weight < min_weight) {
min_element = curr;
min_weight = new_weight;
}
}
boundaries_to_keep.erase(min_element);
}
// Return boundaries.
output.reserve(compressed_summary.entries_.size());
for (const auto& entry : compressed_summary.entries_) {
output.push_back(entry.value);
output.reserve(boundaries_to_keep.size());
for (auto itr = boundaries_to_keep.begin(); itr != boundaries_to_keep.end();
++itr) {
output.push_back(compressed_summary.entries_[*itr].value);
}
return output;
}

View File

@ -82,7 +82,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
self.eps = 0.01
self.max_elements = 1 << 16
self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
self.num_quantiles = constant_op.constant(4, dtype=dtypes.int64)
def testBasicQuantileBucketsSingleResource(self):
with self.cached_session() as sess:
@ -183,7 +183,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess:
accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
num_streams=2,
num_quantiles=self.num_quantiles,
epsilon=self.eps,
name="q0")
save = saver.Saver()
resources.initialize_resources(resources.shared_resources()).run()
@ -202,7 +205,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
with self.session(graph=ops.Graph()) as sess:
accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
num_streams=2,
num_quantiles=self.num_quantiles,
epsilon=self.eps,
name="q0")
save = saver.Saver()
save.restore(sess, save_path)
buckets = accumulator.get_bucket_boundaries()
@ -215,7 +221,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess:
accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
num_streams=2,
num_quantiles=self.num_quantiles,
epsilon=self.eps,
name="q0")
save = saver.Saver()
resources.initialize_resources(resources.shared_resources()).run()
@ -233,7 +242,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
with self.session(graph=ops.Graph()) as sess:
accumulator = boosted_trees_ops.QuantileAccumulator(
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
num_streams=2,
num_quantiles=self.num_quantiles,
epsilon=self.eps,
name="q0")
save = saver.Saver()
save.restore(sess, save_path)
buckets = accumulator.get_bucket_boundaries()