Improve the bucket generation in Boosted Trees to avoid returning more than requested buckets.
PiperOrigin-RevId: 312371738 Change-Id: I7f241c839f52d679ad4ceb82c161018e9b944fa3
commit 91da977a03
parent d894109fe1
tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h

@@ -16,6 +16,7 @@
 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
 
 #include <cstring>
+#include <list>
 #include <vector>
 
 #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
@@ -250,10 +251,37 @@ class WeightedQuantilesSummary {
     float compression_eps = ApproximationError() + (1.0 / num_boundaries);
     compressed_summary.Compress(num_boundaries, compression_eps);
 
+    // Remove the least important boundaries by the gap removing them would
+    // create.
+    std::list<int64> boundaries_to_keep;
+    for (int64 i = 0; i != compressed_summary.entries_.size(); ++i) {
+      boundaries_to_keep.push_back(i);
+    }
+    while (boundaries_to_keep.size() > num_boundaries) {
+      std::list<int64>::iterator min_element = boundaries_to_keep.end();
+      auto prev = boundaries_to_keep.begin();
+      auto curr = prev;
+      ++curr;
+      auto next = curr;
+      ++next;
+      WeightType min_weight = TotalWeight();
+      for (; next != boundaries_to_keep.end(); ++prev, ++curr, ++next) {
+        WeightType new_weight =
+            compressed_summary.entries_[*next].PrevMaxRank() -
+            compressed_summary.entries_[*prev].NextMinRank();
+        if (new_weight < min_weight) {
+          min_element = curr;
+          min_weight = new_weight;
+        }
+      }
+      boundaries_to_keep.erase(min_element);
+    }
+
     // Return boundaries.
-    output.reserve(compressed_summary.entries_.size());
-    for (const auto& entry : compressed_summary.entries_) {
-      output.push_back(entry.value);
+    output.reserve(boundaries_to_keep.size());
+    for (auto itr = boundaries_to_keep.begin(); itr != boundaries_to_keep.end();
+         ++itr) {
+      output.push_back(compressed_summary.entries_[*itr].value);
     }
     return output;
   }
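For context, here is a minimal, self-contained sketch of the same gap-based pruning idea introduced above: start from every compressed entry and repeatedly drop the interior boundary whose removal would bridge the smallest rank gap, until at most the requested number of boundaries remain. The Entry struct, the PruneBoundaries helper, and the ranks in main are illustrative stand-ins, not the TensorFlow types used in the diff.

// Illustration only: prune candidate bucket boundaries until at most
// `num_boundaries` remain, always dropping the interior boundary whose
// removal bridges the smallest rank gap. Assumes num_boundaries >= 2 so
// the first and last boundaries always survive.
#include <cstddef>
#include <iostream>
#include <iterator>
#include <list>
#include <vector>

struct Entry {
  float value;
  // Simplified rank bounds; the real summary derives these from
  // PrevMaxRank() / NextMinRank().
  double prev_max_rank;
  double next_min_rank;
};

std::vector<float> PruneBoundaries(const std::vector<Entry>& entries,
                                   std::size_t num_boundaries,
                                   double total_weight) {
  std::list<std::size_t> keep;
  for (std::size_t i = 0; i < entries.size(); ++i) keep.push_back(i);

  while (keep.size() > num_boundaries) {
    auto min_element = keep.end();
    auto prev = keep.begin();
    auto curr = std::next(prev);
    auto next = std::next(curr);
    double min_gap = total_weight;  // Upper bound on any rank gap.
    // The first and last boundaries are never candidates for removal.
    for (; next != keep.end(); ++prev, ++curr, ++next) {
      double gap = entries[*next].prev_max_rank - entries[*prev].next_min_rank;
      if (gap < min_gap) {
        min_element = curr;
        min_gap = gap;
      }
    }
    keep.erase(min_element);
  }

  std::vector<float> output;
  output.reserve(keep.size());
  for (std::size_t idx : keep) output.push_back(entries[idx].value);
  return output;
}

int main() {
  // Five candidate boundaries with made-up ranks; request at most four.
  std::vector<Entry> entries = {{0.0f, 0, 0},   {1.0f, 10, 12}, {1.5f, 14, 15},
                                {3.0f, 40, 42}, {5.0f, 60, 60}};
  for (float v : PruneBoundaries(entries, 4, /*total_weight=*/60.0)) {
    std::cout << v << " ";  // Prints the surviving boundaries: 0 1.5 3 5
  }
  std::cout << "\n";
  return 0;
}

The boundary at 1.0 is removed because bridging its neighbors creates the smallest rank gap, which mirrors how the loop in the diff starts at the second entry and stops before the last so that the extreme boundaries are always kept.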
tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py

@@ -82,7 +82,7 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
 
     self.eps = 0.01
     self.max_elements = 1 << 16
-    self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
+    self.num_quantiles = constant_op.constant(4, dtype=dtypes.int64)
 
   def testBasicQuantileBucketsSingleResource(self):
     with self.cached_session() as sess:
@@ -183,7 +183,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
 
     with self.cached_session() as sess:
       accumulator = boosted_trees_ops.QuantileAccumulator(
-          num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
+          num_streams=2,
+          num_quantiles=self.num_quantiles,
+          epsilon=self.eps,
+          name="q0")
 
       save = saver.Saver()
       resources.initialize_resources(resources.shared_resources()).run()
@@ -202,7 +205,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
 
     with self.session(graph=ops.Graph()) as sess:
       accumulator = boosted_trees_ops.QuantileAccumulator(
-          num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
+          num_streams=2,
+          num_quantiles=self.num_quantiles,
+          epsilon=self.eps,
+          name="q0")
       save = saver.Saver()
       save.restore(sess, save_path)
       buckets = accumulator.get_bucket_boundaries()
@@ -215,7 +221,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
 
     with self.cached_session() as sess:
       accumulator = boosted_trees_ops.QuantileAccumulator(
-          num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
+          num_streams=2,
+          num_quantiles=self.num_quantiles,
+          epsilon=self.eps,
+          name="q0")
 
       save = saver.Saver()
       resources.initialize_resources(resources.shared_resources()).run()
@@ -233,7 +242,10 @@ class QuantileOpsTest(test_util.TensorFlowTestCase):
 
     with self.session(graph=ops.Graph()) as sess:
       accumulator = boosted_trees_ops.QuantileAccumulator(
-          num_streams=2, num_quantiles=3, epsilon=self.eps, name="q0")
+          num_streams=2,
+          num_quantiles=self.num_quantiles,
+          epsilon=self.eps,
+          name="q0")
       save = saver.Saver()
       save.restore(sess, save_path)
       buckets = accumulator.get_bucket_boundaries()