Add fixed space sparse class stats handling.

PiperOrigin-RevId: 169570470
This commit is contained in:
A. Unique TensorFlower 2017-09-21 11:18:05 -07:00 committed by TensorFlower Gardener
parent 2679dcfbaa
commit 054b88233b
5 changed files with 439 additions and 88 deletions

View File

@ -159,7 +159,7 @@ void ClassificationStats::AdditionalInitializationExample(
}
bool ClassificationStats::IsFinished() const {
bool basic = weight_sum_ >= split_after_samples_ && num_outputs_seen() > 1;
bool basic = (weight_sum_ >= split_after_samples_) && !is_pure();
return basic || finish_early_;
}
@ -193,8 +193,11 @@ void ClassificationStats::AddExample(
left_gini_->update(i, left_count(i, int_label), weight);
}
ClassificationAddLeftExample(i, int_label, weight);
} else if (right_gini_ != nullptr) {
right_gini_->update(i, right_count(i, int_label), weight);
} else {
if (right_gini_ != nullptr) {
right_gini_->update(i, right_count(i, int_label), weight);
}
ClassificationAddRightExample(i, int_label, weight);
}
}
@ -374,6 +377,41 @@ void ClassificationStats::CheckFinishEarlyBootstrap() {
finish_early_ = worst_g1 < best_g2;
}
// Chooses the candidate split with the lowest Gini score among those that
// actually separate the data, fills `best` with that split plus the
// left/right leaf statistics, and returns true.  Returns false when no
// viable split exists.
bool ClassificationStats::BestSplit(SplitCandidate* best) const {
  float min_score = FLT_MAX;
  int best_index = -1;
  // Initialized to match the guard below; never read unless best_index >= 0,
  // but initializing avoids maybe-uninitialized warnings and keeps this
  // consistent with the sparse implementation this was factored from.
  float best_left_sum = -1;
  float best_right_sum = -1;
  for (int i = 0; i < num_splits(); ++i) {
    float left_sum, right_sum;
    const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
    // A split that sends all weight to one side is useless; skip it.
    if (left_sum > 0 && right_sum > 0 && split_score < min_score) {
      min_score = split_score;
      best_index = i;
      best_left_sum = left_sum;
      best_right_sum = right_sum;
    }
  }

  // This could happen if all the splits are useless.
  if (best_index < 0) {
    return false;
  }

  // Fill in stats to be used for the leaf model.
  *best->mutable_split() = splits_[best_index];
  auto* left = best->mutable_left_stats();
  left->set_weight_sum(best_left_sum);
  auto* right = best->mutable_right_stats();
  right->set_weight_sum(best_right_sum);
  InitLeafClassStats(best_index, left, right);
  return true;
}
// ------------------------ Dense Classification --------------------------- //
void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
Initialize();
@ -449,52 +487,20 @@ float DenseClassificationGrowStats::GiniScore(int split, float* left_sum,
return left_score + right_score;
}
bool DenseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
float min_score = FLT_MAX;
int best_index = -1;
float best_left_sum, best_right_sum;
// Calculate sums.
for (int i = 0; i < num_splits(); ++i) {
float left_sum, right_sum;
const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
// Find the lowest gini.
if (left_sum > 0 && right_sum > 0 &&
split_score < min_score) { // useless check
min_score = split_score;
best_index = i;
best_left_sum = left_sum;
best_right_sum = right_sum;
}
}
// This could happen if all the splits are useless.
if (best_index < 0) {
return false;
}
// Fill in stats to be used for leaf model.
*best->mutable_split() = splits_[best_index];
// Left
auto* left = best->mutable_left_stats();
auto* left_class_stats = left->mutable_classification();
left->set_weight_sum(best_left_sum);
void DenseClassificationGrowStats::InitLeafClassStats(
int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
auto* left_class_stats = left_stats->mutable_classification();
auto* left_counts = left_class_stats->mutable_dense_counts();
for (int i = 0; i < params_.num_outputs(); ++i) {
left_counts->add_value()->set_float_value(
left_count(best_index, i));
left_counts->add_value()->set_float_value(left_count(best_split_index, i));
}
// Right
auto* right = best->mutable_right_stats();
auto* right_class_stats = right->mutable_classification();
right->set_weight_sum(best_right_sum);
auto* right_class_stats = right_stats->mutable_classification();
auto* right_counts = right_class_stats->mutable_dense_counts();
for (int i = 0; i < params_.num_outputs(); ++i) {
right_counts->add_value()->set_float_value(
total_counts_[i] - left_count(best_index, i));
right_counts->add_value()->set_float_value(total_counts_[i] -
left_count(best_split_index, i));
}
return true;
}
// ------------------------ Sparse Classification --------------------------- //
@ -584,49 +590,18 @@ float SparseClassificationGrowStats::GiniScore(
return left_score + right_score;
}
bool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
float min_score = FLT_MAX;
int best_index = -1;
float best_left_sum = -1;
float best_right_sum = -1;
// Find the lowest gini.
for (int i = 0; i < num_splits(); ++i) {
float left_sum, right_sum;
const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
if (left_sum > 0 && right_sum > 0 &&
split_score < min_score) { // useless check
min_score = split_score;
best_index = i;
best_left_sum = left_sum;
best_right_sum = right_sum;
}
}
// This could happen if all the splits are useless.
if (best_index < 0) {
return false;
}
// Fill in stats to be used for leaf model.
*best->mutable_split() = splits_[best_index];
// Left
auto* left = best->mutable_left_stats();
auto* left_class_stats = left->mutable_classification();
left->set_weight_sum(best_left_sum);
void SparseClassificationGrowStats::InitLeafClassStats(
int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
auto* left_class_stats = left_stats->mutable_classification();
auto* left_counts =
left_class_stats->mutable_sparse_counts()->mutable_sparse_value();
// Right
auto* right = best->mutable_right_stats();
auto* right_class_stats = right->mutable_classification();
right->set_weight_sum(best_right_sum);
auto* right_class_stats = right_stats->mutable_classification();
auto* right_counts =
right_class_stats->mutable_sparse_counts()->mutable_sparse_value();
for (const auto& entry : total_counts_) {
auto it = left_counts_[best_index].find(entry.first);
if (it == left_counts_[best_index].end()) {
auto it = left_counts_[best_split_index].find(entry.first);
if (it == left_counts_[best_split_index].end()) {
(*right_counts)[entry.first].set_float_value(entry.second);
} else {
const float left = it->second;
@ -637,7 +612,184 @@ bool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
}
}
}
return true;
}
// -------------------- FixedSizeClassStats --------------------------------- //
// FixedSizeClassStats implements the "SpaceSaving" algorithm by
// Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi. See for example
// https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf
// Returns the key of the map entry with the smallest value, or -1 if the
// map is empty.  A linear scan is fine here: callers only pass maps with at
// most n_ entries (typically 10-100).
int argmin(const std::unordered_map<int, float>& m) {
  int min_class = -1;
  float min_weight = FLT_MAX;
  // Iterate by const reference to avoid copying each pair.
  for (const auto& entry : m) {
    if (entry.second < min_weight) {
      min_weight = entry.second;
      min_class = entry.first;
    }
  }
  return min_class;
}
// Adds weight w to class c using the SpaceSaving update rule: already
// tracked classes are updated in place; new classes either fill spare
// capacity or take over the entry with the smallest weight.
void FixedSizeClassStats::accumulate(int c, float w) {
  auto it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    it->second += w;
    // If the smallest class just grew, another entry may now be smallest.
    if (c == smallest_weight_class_) {
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  if (class_weights_.size() < n_) {
    class_weights_.insert(it, std::pair<int, float>(c, w));
    if (class_weights_.size() == n_) {
      // Can't assume last added has the smallest weight, because the
      // w's might be all different.
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  // This is the slightly unintuitive heart of the SpaceSaving algorithm:
  // if the map is full and we see a new class, we find the entry with the
  // smallest weight and "take it over": we add our weight to its weight,
  // and assign it all to the new seen class.
  it = class_weights_.find(smallest_weight_class_);
  float new_weight = it->second + w;
  class_weights_.erase(it);
  class_weights_[c] = new_weight;
  smallest_weight_class_ = argmin(class_weights_);
}
// Returns the approximate accumulated weight for class c.  Every entry in
// class_weights_ might be overstated by as much as the smallest weight
// (the SpaceSaving error bound).  We therefore assume that each has been
// overstated by smallest_weight / 2.0, and we re-distribute that mass over
// all num_classes_ classes.
float FixedSizeClassStats::get_weight(int c) const {
  // smallest_weight_class_ stays -1 until the map fills up; the find()
  // below then fails (assuming class ids are non-negative) and the
  // correction term is zero, so the returned weights are exact.
  float smallest_weight = 0.0;
  auto it = class_weights_.find(smallest_weight_class_);
  if (it != class_weights_.end()) {
    smallest_weight = it->second;
  }

  // Share of the re-distributed mass that every class receives, tracked
  // or not.
  float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
  it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    w += it->second - smallest_weight / 2.0;
  }
  return w;
}
// Puts the sum of all weights seen into *sum (exact) and
//   \sum_c get_weight(c)^2
// into *square (approximate).  All (num_classes_ - n_) untracked classes
// share the same re-distributed weight, so their squares are added in
// closed form instead of one at a time.
void FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const {
  *sum = 0.0;
  *square = 0.0;

  // Weight of the smallest tracked class; zero while the map isn't full
  // yet (in which case the stats are exact).
  float smallest_weight = 0.0;
  auto it = class_weights_.find(smallest_weight_class_);
  if (it != class_weights_.end()) {
    smallest_weight = it->second;
  }

  float w;
  // Iterate by const reference to avoid copying each map entry.
  for (const auto& entry : class_weights_) {
    *sum += entry.second;
    w = get_weight(entry.first);
    *square += w * w;
  }

  // Every untracked class gets the same re-distributed share of the
  // smallest weight.
  w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
  *square += (num_classes_ - n_) * w * w;
}
// Restores the tracked class weights from a SparseVector proto previously
// produced by PackToProto.
void FixedSizeClassStats::ExtractFromProto(
    const decision_trees::SparseVector& sparse_vector) {
  for (const auto& it : sparse_vector.sparse_value()) {
    class_weights_[it.first] = it.second.float_value();
  }
  // The proto holds at most n_ entries; the smallest-weight class is only
  // meaningful (and only maintained) once the map is full.
  if (class_weights_.size() == n_) {
    smallest_weight_class_ = argmin(class_weights_);
  }
}
// Serializes the tracked class weights into a SparseVector proto keyed by
// class id.  Inverse of ExtractFromProto.
void FixedSizeClassStats::PackToProto(
    decision_trees::SparseVector* sparse_vector) const {
  // Iterate by const reference to avoid copying each map entry.
  for (const auto& entry : class_weights_) {
    (*sparse_vector->mutable_sparse_value())[entry.first].set_float_value(
        entry.second);
  }
}
// --------------------- FixedSizeSparseClassificationGrowStats ------------- //
// Rebuilds this slot's growth statistics from a FertileSlot proto: the
// accumulated weight sum plus, for each candidate split, the fixed-size
// left/right class-weight sketches.
void FixedSizeSparseClassificationGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();

  // Candidate counts and splits.
  int split_num = 0;
  left_counts_.clear();
  right_counts_.clear();
  for (const auto& cand : slot.candidates()) {
    // NOTE(review): the nullptr/-1 arguments appear to skip per-example
    // bookkeeping, since the counts are restored from the sketches below —
    // confirm against AddSplit.
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().sparse_counts();
    left_counts_.emplace_back(params_.num_classes_to_track(),
                              params_.num_outputs());
    left_counts_[split_num].ExtractFromProto(left_stats);
    const auto& right_stats =
        cand.right_stats().classification().sparse_counts();
    right_counts_.emplace_back(params_.num_classes_to_track(),
                               params_.num_outputs());
    right_counts_[split_num].ExtractFromProto(right_stats);
    ++split_num;
  }
}
// Serializes this slot's growth statistics into a FertileSlot proto: the
// total weight sum plus each candidate split with its left/right
// class-weight sketches.  Inverse of ExtractFromProto.
void FixedSizeSparseClassificationGrowStats::PackToProto(
    FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);
  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts();
    left_counts_[split_num].PackToProto(left_stats);
    auto* right_stats = cand->mutable_right_stats()
                            ->mutable_classification()
                            ->mutable_sparse_counts();
    right_counts_[split_num].PackToProto(right_stats);
  }
}
// Returns the weighted, smoothed Gini impurity of candidate `split`,
// summed over the would-be left and right children.  The total weight of
// each side is also written through *left_sum and *right_sum.
float FixedSizeSparseClassificationGrowStats::GiniScore(
    int split, float* left_sum, float* right_sum) const {
  const int32 num_classes = params_.num_outputs();
  float sum_of_squares = 0.0;
  left_counts_[split].set_sum_and_square(left_sum, &sum_of_squares);
  float score = WeightedSmoothedGini(*left_sum, sum_of_squares, num_classes);
  right_counts_[split].set_sum_and_square(right_sum, &sum_of_squares);
  score += WeightedSmoothedGini(*right_sum, sum_of_squares, num_classes);
  return score;
}
// Copies the winning split's class-weight sketches into the LeafStat protos
// of the two leaves that are about to be created.
void FixedSizeSparseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  left_counts_[best_split_index].PackToProto(
      left_stats->mutable_classification()->mutable_sparse_counts());
  right_counts_[best_split_index].PackToProto(
      right_stats->mutable_classification()->mutable_sparse_counts());
}
// --------------------- Least Squares Regression --------------------------- //

View File

@ -189,15 +189,29 @@ class ClassificationStats : public GrowStats {
half_initialized_splits_.empty());
}
bool BestSplit(SplitCandidate* best) const override;
// When best_split_index has been chosen as the best split,
// InitLeafClassStats is used to initialize the LeafStat's of the two
// new leaves.
virtual void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
LeafStat* right_stats) const = 0;
protected:
virtual float GiniScore(int split, float* left_sum,
float* right_sum) const = 0;
virtual int num_outputs_seen() const = 0;
// is_pure should return true if at most one class label has been seen
// at the node, and false if two or more have been seen.
virtual bool is_pure() const = 0;
virtual float left_count(int split, int class_num) const = 0;
virtual float right_count(int split, int class_num) const = 0;
virtual void ClassificationAddLeftExample(
int split, int64 int_label, float weight) = 0;
virtual void ClassificationAddRightExample(int split, int64 int_label,
float weight) {
// Does nothing by default, but sub-classes can override.
}
virtual void ClassificationAddTotalExample(int64 int_label, float weight) = 0;
virtual void ClassificationAddSplitStats() = 0;
@ -301,7 +315,8 @@ class DenseClassificationGrowStats : public ClassificationStats {
void ExtractFromProto(const FertileSlot& slot) override;
void PackToProto(FertileSlot* slot) const override;
bool BestSplit(SplitCandidate* best) const override;
void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
LeafStat* right_stats) const;
protected:
void ClassificationAddSplitStats() override {
@ -317,9 +332,7 @@ class DenseClassificationGrowStats : public ClassificationStats {
num_outputs_seen_ = 0;
}
int num_outputs_seen() const override {
return num_outputs_seen_;
}
bool is_pure() const override { return num_outputs_seen_ <= 1; }
void ClassificationAddLeftExample(int split, int64 int_label,
float weight) override {
@ -369,7 +382,8 @@ class SparseClassificationGrowStats : public ClassificationStats {
void ExtractFromProto(const FertileSlot& slot) override;
void PackToProto(FertileSlot* slot) const override;
bool BestSplit(SplitCandidate* best) const override;
void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
LeafStat* right_stats) const;
protected:
void ClassificationAddSplitStats() override {
@ -384,7 +398,7 @@ class SparseClassificationGrowStats : public ClassificationStats {
left_counts_.clear();
}
int num_outputs_seen() const override { return total_counts_.size(); }
bool is_pure() const override { return total_counts_.size() <= 1; }
void ClassificationAddLeftExample(int split, int64 int_label,
float weight) override {
@ -412,6 +426,111 @@ class SparseClassificationGrowStats : public ClassificationStats {
std::vector<std::unordered_map<int, float>> left_counts_;
};
// Accumulates weights for the most popular classes while only using a
// fixed amount of space.
//
// Implements the "SpaceSaving" sketch: results are exact while at most n
// distinct classes have been seen, and approximate (with a bounded
// overestimate) afterwards.
class FixedSizeClassStats {
 public:
  // n specifies how many classes are tracked; num_classes is the total
  // number of classes weights may be requested for.
  FixedSizeClassStats(int n, int num_classes)
      : n_(n), num_classes_(num_classes), smallest_weight_class_(-1) {}

  // Add weight w to the class c.
  void accumulate(int c, float w);

  // Return the approximate accumulated weight for class c. If c isn't one
  // of the n-most popular classes, this can be 0 even if c has accumulated
  // some weight.
  float get_weight(int c) const;

  // Put the sum of all weights seen into *sum, and
  //   \sum_c get_weight(c)^2
  // into *square. *sum will be exact, but *square will be approximate.
  void set_sum_and_square(float* sum, float* square) const;

  // Serialization to/from a SparseVector proto keyed by class id.
  void ExtractFromProto(const decision_trees::SparseVector& sparse_vector);
  void PackToProto(decision_trees::SparseVector* sparse_vector) const;

 private:
  // For our typical use cases, n_ is between 10 and 100, so there's no
  // need to track the smallest weight with a min_heap or the like.
  int n_;
  int num_classes_;

  // This tracks the class of the smallest weight, but isn't set until
  // class_weights_.size() == n_.
  int smallest_weight_class_;

  std::unordered_map<int, float> class_weights_;
};
// Tracks classification stats sparsely in a fixed amount of space.
class FixedSizeSparseClassificationGrowStats : public ClassificationStats {
public:
FixedSizeSparseClassificationGrowStats(const TensorForestParams& params,
int32 depth)
: ClassificationStats(params, depth) {}
void Initialize() override { Clear(); }
void ExtractFromProto(const FertileSlot& slot) override;
void PackToProto(FertileSlot* slot) const override;
void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
LeafStat* right_stats) const;
protected:
void ClassificationAddSplitStats() override {
FixedSizeClassStats stats(params_.num_classes_to_track(),
params_.num_outputs());
left_counts_.resize(num_splits(), stats);
right_counts_.resize(num_splits(), stats);
}
void ClassificationRemoveSplitStats(int split_num) override {
left_counts_.erase(left_counts_.begin() + split_num,
left_counts_.begin() + (split_num + 1));
right_counts_.erase(right_counts_.begin() + split_num,
right_counts_.begin() + (split_num + 1));
}
void ClearInternal() override {
left_counts_.clear();
right_counts_.clear();
}
bool is_pure() const override { return first_two_classes_seen_.size() <= 1; }
void ClassificationAddLeftExample(int split, int64 int_label,
float weight) override {
left_counts_[split].accumulate(int_label, weight);
}
void ClassificationAddRightExample(int split, int64 int_label,
float weight) override {
right_counts_[split].accumulate(int_label, weight);
}
void ClassificationAddTotalExample(int64 int_label, float weight) override {
if (is_pure()) {
first_two_classes_seen_.insert(int_label);
}
}
float GiniScore(int split, float* left_sum, float* right_sum) const override;
float left_count(int split, int class_num) const override {
return left_counts_[split].get_weight(class_num);
}
float right_count(int split, int class_num) const override {
return right_counts_[split].get_weight(class_num);
}
private:
std::vector<FixedSizeClassStats> left_counts_;
std::vector<FixedSizeClassStats> right_counts_;
// We keep track of the first two class labels seen, so we can tell if
// the node is pure (= all of one class) or not.
std::set<int> first_two_classes_seen_;
};
// Tracks regression stats using least-squares minimization.
class LeastSquaresRegressionGrowStats : public GrowStats {
public:

View File

@ -29,6 +29,8 @@ using tensorflow::tensorforest::TestableInputTarget;
using tensorflow::tensorforest::FertileSlot;
using tensorflow::tensorforest::DenseClassificationGrowStats;
using tensorflow::tensorforest::SparseClassificationGrowStats;
using tensorflow::tensorforest::FixedSizeClassStats;
using tensorflow::tensorforest::FixedSizeSparseClassificationGrowStats;
using tensorflow::tensorforest::LeastSquaresRegressionGrowStats;
using tensorflow::tensorforest::TensorForestParams;
using tensorflow::tensorforest::SPLIT_FINISH_BASIC;
@ -327,7 +329,6 @@ TEST(GrowStatsLeastSquaresRegressionTest, Basic) {
ASSERT_EQ(serialized_again, serialized);
}
TEST(GrowStatsSparseClassificationTest, Basic) {
TensorForestParams params;
params.set_num_outputs(2);
@ -360,5 +361,74 @@ TEST(GrowStatsSparseClassificationTest, Basic) {
ASSERT_EQ(serialized_again, serialized);
}
// With fewer distinct classes (3) than tracked slots (10), the sketch is
// exact: individual weights, their sum, and the sum of squares all match.
TEST(FixedSizeClassStats, Exact) {
  FixedSizeClassStats stats(10, 100);

  stats.accumulate(1, 1.0);
  stats.accumulate(2, 2.0);
  stats.accumulate(3, 3.0);

  EXPECT_EQ(stats.get_weight(1), 1.0);
  EXPECT_EQ(stats.get_weight(2), 2.0);
  EXPECT_EQ(stats.get_weight(3), 3.0);

  float sum;
  float square;
  stats.set_sum_and_square(&sum, &square);
  EXPECT_EQ(sum, 6.0);
  // 1^2 + 2^2 + 3^2.
  EXPECT_EQ(square, 14.0);
}
// With more distinct classes (10) than tracked slots (5), weights are
// approximate but must stay within the SpaceSaving error bound.
TEST(FixedSizeClassStats, Approximate) {
  FixedSizeClassStats stats(5, 10);

  for (int i = 1; i <= 10; i++) {
    stats.accumulate(i, i * 1.0);
  }

  // We should be off by no more than *half* of the least weight
  // in the class_weights_, which is 7.
  float tolerance = 3.5;
  for (int i = 1; i <= 10; i++) {
    float diff = stats.get_weight(i) - i * 1.0;
    EXPECT_LE(diff, tolerance);
    EXPECT_GE(diff, -tolerance);
  }
}
// End-to-end round trip: grow stats from a small batch, serialize to a
// FertileSlot proto, restore into a fresh object, and verify the
// re-serialized proto is byte-identical.
TEST(GrowStatsFixedSizeSparseClassificationTest, Basic) {
  TensorForestParams params;
  params.set_num_outputs(2);
  params.set_num_classes_to_track(5);
  params.mutable_split_after_samples()->set_constant_value(2);
  params.mutable_num_splits_to_consider()->set_constant_value(2);

  std::unique_ptr<FixedSizeSparseClassificationGrowStats> stat(
      new FixedSizeSparseClassificationGrowStats(params, 1));
  stat->Initialize();

  std::vector<float> labels = {100, 1000, 1};
  std::vector<float> weights = {2.3, 20.3, 1.1};
  std::unique_ptr<TestableInputTarget> target(
      new TestableInputTarget(labels, weights, 1));

  // (Removed an unused `branches` vector; RunBatch only needs the target.)
  RunBatch(stat.get(), target.get());
  CHECK(stat->IsFinished());

  FertileSlot slot;
  stat->PackToProto(&slot);
  string serialized = slot.DebugString();

  std::unique_ptr<FixedSizeSparseClassificationGrowStats> new_stat(
      new FixedSizeSparseClassificationGrowStats(params, 1));
  new_stat->ExtractFromProto(slot);
  FertileSlot second_one;
  new_stat->PackToProto(&second_one);
  string serialized_again = second_one.DebugString();
  ASSERT_EQ(serialized_again, serialized);
}
} // namespace
} // namespace tensorflow

View File

@ -55,6 +55,10 @@ std::unique_ptr<GrowStats> SplitCollectionOperator::CreateGrowStats(
return std::unique_ptr<GrowStats>(new LeastSquaresRegressionGrowStats(
params_, depth));
case STATS_FIXED_SIZE_SPARSE_GINI:
return std::unique_ptr<GrowStats>(
new FixedSizeSparseClassificationGrowStats(params_, depth));
default:
LOG(ERROR) << "Unknown grow stats type: " << params_.stats_type();
return nullptr;

View File

@ -20,7 +20,9 @@ enum StatsModelType {
STATS_DENSE_GINI = 0;
STATS_SPARSE_GINI = 1;
STATS_LEAST_SQUARES_REGRESSION = 2;
// STATS_SPARSE_THEN_DENSE_GINI is deprecated and no longer supported.
STATS_SPARSE_THEN_DENSE_GINI = 3;
STATS_FIXED_SIZE_SPARSE_GINI = 4;
}
// Allows selection of operations on the collection of split candidates.
@ -145,4 +147,8 @@ message TensorForestParams {
// --------- Parameters for experimental features ---------------------- //
string graph_dir = 16;
int32 num_select_features = 17;
// When using a FixedSizeSparseClassificationGrowStats, keep track of
// this many classes.
int32 num_classes_to_track = 24;
}