Improve the bucket generation in Boosted Trees to avoid returning more than requested buckets.
PiperOrigin-RevId: 312371738 Change-Id: I7f241c839f52d679ad4ceb82c161018e9b944fa3
This commit is contained in:
parent
d894109fe1
commit
91da977a03
|
@ -16,6 +16,7 @@
|
|||
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
|
||||
|
||||
#include <cstring>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
|
||||
|
@ -250,10 +251,37 @@ class WeightedQuantilesSummary {
|
|||
float compression_eps = ApproximationError() + (1.0 / num_boundaries);
|
||||
compressed_summary.Compress(num_boundaries, compression_eps);
|
||||
|
||||
// Remove the least important boundaries by the gap removing them would
|
||||
// create.
|
||||
std::list<int64> boundaries_to_keep;
|
||||
for (int64 i = 0; i != compressed_summary.entries_.size(); ++i) {
|
||||
boundaries_to_keep.push_back(i);
|
||||
}
|
||||
while (boundaries_to_keep.size() > num_boundaries) {
|
||||
std::list<int64>::iterator min_element = boundaries_to_keep.end();
|
||||
auto prev = boundaries_to_keep.begin();
|
||||
auto curr = prev;
|
||||
++curr;
|
||||
auto next = curr;
|
||||
++next;
|
||||
WeightType min_weight = TotalWeight();
|
||||
for (; next != boundaries_to_keep.end(); ++prev, ++curr, ++next) {
|
||||
WeightType new_weight =
|
||||
compressed_summary.entries_[*next].PrevMaxRank() -
|
||||
compressed_summary.entries_[*prev].NextMinRank();
|
||||
if (new_weight < min_weight) {
|
||||
min_element = curr;
|
||||
min_weight = new_weight;
|
||||
}
|
||||
}
|
||||
boundaries_to_keep.erase(min_element);
|
||||
}
|
||||
|
||||
// Return boundaries.
|
||||
output.reserve(compressed_summary.entries_.size());
|
||||
for (const auto& entry : compressed_summary.entries_) {
|
||||
output.push_back(entry.value);
|
||||
output.reserve(boundaries_to_keep.size());
|
||||
for (auto itr = boundaries_to_keep.begin(); itr != boundaries_to_keep.end();
|
||||
++itr) {
|
||||
output.push_back(compressed_summary.entries_[*itr].value);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
|
|
@ -82,7 +82,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||
|
||||
self.eps = 0.01
|
||||
self.max_elements = 1 << 16
|
||||
self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
|
||||
self.num_quantiles = constant_op.constant(4, dtype=dtypes.int64)
|
||||
|
||||
def testBasicQuantileBucketsSingleResource(self):
|
||||
with self.cached_session() as sess:
|
||||
|
@ -183,7 +183,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||
|
||||
with self.cached_session() as sess:
|
||||
accumulator = boosted_trees_ops.QuantileAccumulator(
|
||||
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
|
||||
num_streams=2,
|
||||
num_quantiles=self.num_quantiles,
|
||||
epsilon=self.eps,
|
||||
name="q0")
|
||||
|
||||
save = saver.Saver()
|
||||
resources.initialize_resources(resources.shared_resources()).run()
|
||||
|
@ -202,7 +205,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
accumulator = boosted_trees_ops.QuantileAccumulator(
|
||||
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
|
||||
num_streams=2,
|
||||
num_quantiles=self.num_quantiles,
|
||||
epsilon=self.eps,
|
||||
name="q0")
|
||||
save = saver.Saver()
|
||||
save.restore(sess, save_path)
|
||||
buckets = accumulator.get_bucket_boundaries()
|
||||
|
@ -215,7 +221,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||
|
||||
with self.cached_session() as sess:
|
||||
accumulator = boosted_trees_ops.QuantileAccumulator(
|
||||
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
|
||||
num_streams=2,
|
||||
num_quantiles=self.num_quantiles,
|
||||
epsilon=self.eps,
|
||||
name="q0")
|
||||
|
||||
save = saver.Saver()
|
||||
resources.initialize_resources(resources.shared_resources()).run()
|
||||
|
@ -233,7 +242,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
|
|||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
accumulator = boosted_trees_ops.QuantileAccumulator(
|
||||
num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
|
||||
num_streams=2,
|
||||
num_quantiles=self.num_quantiles,
|
||||
epsilon=self.eps,
|
||||
name="q0")
|
||||
save = saver.Saver()
|
||||
save.restore(sess, save_path)
|
||||
buckets = accumulator.get_bucket_boundaries()
|
||||
|
|
Loading…
Reference in New Issue