Add fixed-size sparse class stats handling.
PiperOrigin-RevId: 169570470
This commit is contained in:
parent
2679dcfbaa
commit
054b88233b
@ -159,7 +159,7 @@ void ClassificationStats::AdditionalInitializationExample(
|
||||
}
|
||||
|
||||
bool ClassificationStats::IsFinished() const {
|
||||
bool basic = weight_sum_ >= split_after_samples_ && num_outputs_seen() > 1;
|
||||
bool basic = (weight_sum_ >= split_after_samples_) && !is_pure();
|
||||
return basic || finish_early_;
|
||||
}
|
||||
|
||||
@ -193,8 +193,11 @@ void ClassificationStats::AddExample(
|
||||
left_gini_->update(i, left_count(i, int_label), weight);
|
||||
}
|
||||
ClassificationAddLeftExample(i, int_label, weight);
|
||||
} else if (right_gini_ != nullptr) {
|
||||
right_gini_->update(i, right_count(i, int_label), weight);
|
||||
} else {
|
||||
if (right_gini_ != nullptr) {
|
||||
right_gini_->update(i, right_count(i, int_label), weight);
|
||||
}
|
||||
ClassificationAddRightExample(i, int_label, weight);
|
||||
}
|
||||
}
|
||||
|
||||
@ -374,6 +377,41 @@ void ClassificationStats::CheckFinishEarlyBootstrap() {
|
||||
finish_early_ = worst_g1 < best_g2;
|
||||
}
|
||||
|
||||
// Finds the split candidate with the lowest Gini score and fills in *best
// with its split and left/right leaf stats.  Returns false if no candidate
// sends weight to both sides (i.e. every split is useless).
bool ClassificationStats::BestSplit(SplitCandidate* best) const {
  float min_score = FLT_MAX;
  int best_index = -1;
  // Initialize so the values are always defined; they are only consumed
  // when best_index >= 0, but this matches the sparse implementation and
  // avoids maybe-uninitialized warnings.
  float best_left_sum = -1;
  float best_right_sum = -1;

  // Find the lowest gini.
  for (int i = 0; i < num_splits(); ++i) {
    float left_sum, right_sum;
    const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
    // A split that sends all weight to one side is useless; skip it.
    if (left_sum > 0 && right_sum > 0 && split_score < min_score) {
      min_score = split_score;
      best_index = i;
      best_left_sum = left_sum;
      best_right_sum = right_sum;
    }
  }

  // This could happen if all the splits are useless.
  if (best_index < 0) {
    return false;
  }

  // Fill in stats to be used for leaf model.
  *best->mutable_split() = splits_[best_index];
  auto* left = best->mutable_left_stats();
  left->set_weight_sum(best_left_sum);
  auto* right = best->mutable_right_stats();
  right->set_weight_sum(best_right_sum);
  // Subclasses fill in the class counts for the two new leaves.
  InitLeafClassStats(best_index, left, right);

  return true;
}
|
||||
|
||||
// ------------------------ Dense Classification --------------------------- //
|
||||
void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
|
||||
Initialize();
|
||||
@ -449,52 +487,20 @@ float DenseClassificationGrowStats::GiniScore(int split, float* left_sum,
|
||||
return left_score + right_score;
|
||||
}
|
||||
|
||||
bool DenseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
|
||||
float min_score = FLT_MAX;
|
||||
int best_index = -1;
|
||||
float best_left_sum, best_right_sum;
|
||||
|
||||
// Calculate sums.
|
||||
for (int i = 0; i < num_splits(); ++i) {
|
||||
float left_sum, right_sum;
|
||||
const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
|
||||
// Find the lowest gini.
|
||||
if (left_sum > 0 && right_sum > 0 &&
|
||||
split_score < min_score) { // useless check
|
||||
min_score = split_score;
|
||||
best_index = i;
|
||||
best_left_sum = left_sum;
|
||||
best_right_sum = right_sum;
|
||||
}
|
||||
}
|
||||
|
||||
// This could happen if all the splits are useless.
|
||||
if (best_index < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Fill in stats to be used for leaf model.
|
||||
*best->mutable_split() = splits_[best_index];
|
||||
// Left
|
||||
auto* left = best->mutable_left_stats();
|
||||
auto* left_class_stats = left->mutable_classification();
|
||||
left->set_weight_sum(best_left_sum);
|
||||
void DenseClassificationGrowStats::InitLeafClassStats(
|
||||
int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
|
||||
auto* left_class_stats = left_stats->mutable_classification();
|
||||
auto* left_counts = left_class_stats->mutable_dense_counts();
|
||||
for (int i = 0; i < params_.num_outputs(); ++i) {
|
||||
left_counts->add_value()->set_float_value(
|
||||
left_count(best_index, i));
|
||||
left_counts->add_value()->set_float_value(left_count(best_split_index, i));
|
||||
}
|
||||
|
||||
// Right
|
||||
auto* right = best->mutable_right_stats();
|
||||
auto* right_class_stats = right->mutable_classification();
|
||||
right->set_weight_sum(best_right_sum);
|
||||
auto* right_class_stats = right_stats->mutable_classification();
|
||||
auto* right_counts = right_class_stats->mutable_dense_counts();
|
||||
for (int i = 0; i < params_.num_outputs(); ++i) {
|
||||
right_counts->add_value()->set_float_value(
|
||||
total_counts_[i] - left_count(best_index, i));
|
||||
right_counts->add_value()->set_float_value(total_counts_[i] -
|
||||
left_count(best_split_index, i));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// ------------------------ Sparse Classification --------------------------- //
|
||||
@ -584,49 +590,18 @@ float SparseClassificationGrowStats::GiniScore(
|
||||
return left_score + right_score;
|
||||
}
|
||||
|
||||
bool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
|
||||
float min_score = FLT_MAX;
|
||||
int best_index = -1;
|
||||
float best_left_sum = -1;
|
||||
float best_right_sum = -1;
|
||||
|
||||
// Find the lowest gini.
|
||||
for (int i = 0; i < num_splits(); ++i) {
|
||||
float left_sum, right_sum;
|
||||
const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
|
||||
if (left_sum > 0 && right_sum > 0 &&
|
||||
split_score < min_score) { // useless check
|
||||
min_score = split_score;
|
||||
best_index = i;
|
||||
best_left_sum = left_sum;
|
||||
best_right_sum = right_sum;
|
||||
}
|
||||
}
|
||||
|
||||
// This could happen if all the splits are useless.
|
||||
if (best_index < 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Fill in stats to be used for leaf model.
|
||||
*best->mutable_split() = splits_[best_index];
|
||||
// Left
|
||||
auto* left = best->mutable_left_stats();
|
||||
auto* left_class_stats = left->mutable_classification();
|
||||
left->set_weight_sum(best_left_sum);
|
||||
void SparseClassificationGrowStats::InitLeafClassStats(
|
||||
int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
|
||||
auto* left_class_stats = left_stats->mutable_classification();
|
||||
auto* left_counts =
|
||||
left_class_stats->mutable_sparse_counts()->mutable_sparse_value();
|
||||
|
||||
// Right
|
||||
auto* right = best->mutable_right_stats();
|
||||
auto* right_class_stats = right->mutable_classification();
|
||||
right->set_weight_sum(best_right_sum);
|
||||
auto* right_class_stats = right_stats->mutable_classification();
|
||||
auto* right_counts =
|
||||
right_class_stats->mutable_sparse_counts()->mutable_sparse_value();
|
||||
|
||||
for (const auto& entry : total_counts_) {
|
||||
auto it = left_counts_[best_index].find(entry.first);
|
||||
if (it == left_counts_[best_index].end()) {
|
||||
auto it = left_counts_[best_split_index].find(entry.first);
|
||||
if (it == left_counts_[best_split_index].end()) {
|
||||
(*right_counts)[entry.first].set_float_value(entry.second);
|
||||
} else {
|
||||
const float left = it->second;
|
||||
@ -637,7 +612,184 @@ bool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// -------------------- FixedSizeClassStats --------------------------------- //
|
||||
|
||||
// FixedSizeClassStats implements the "SpaceSaving" algorithm by
|
||||
// Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi. See for example
|
||||
// https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf
|
||||
|
||||
// Returns the key whose mapped value is smallest, or -1 if the map is empty.
int argmin(const std::unordered_map<int, float>& m) {
  int min_class = -1;
  float min_weight = FLT_MAX;
  // Iterate by const reference to avoid copying each pair.
  for (const auto& entry : m) {
    if (entry.second < min_weight) {
      min_weight = entry.second;
      min_class = entry.first;
    }
  }
  return min_class;
}
|
||||
|
||||
void FixedSizeClassStats::accumulate(int c, float w) {
|
||||
auto it = class_weights_.find(c);
|
||||
if (it != class_weights_.end()) {
|
||||
it->second += w;
|
||||
if (c == smallest_weight_class_) {
|
||||
smallest_weight_class_ = argmin(class_weights_);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (class_weights_.size() < n_) {
|
||||
class_weights_.insert(it, std::pair<int, float>(c, w));
|
||||
if (class_weights_.size() == n_) {
|
||||
// Can't assume last added has the smallest weight, because the
|
||||
// w's might be all different.
|
||||
smallest_weight_class_ = argmin(class_weights_);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// This is the slightly unintuitive heart of the SpaceSaving algorithm:
|
||||
// if the map is full and we see a new class, we find the entry with the
|
||||
// smallest weight and "take it over": we add our weight to its weight,
|
||||
// and assign it all to the new seen class.
|
||||
it = class_weights_.find(smallest_weight_class_);
|
||||
float new_weight = it->second + w;
|
||||
class_weights_.erase(it);
|
||||
class_weights_[c] = new_weight;
|
||||
smallest_weight_class_ = argmin(class_weights_);
|
||||
}
|
||||
|
||||
float FixedSizeClassStats::get_weight(int c) const {
|
||||
// Every entry in class_weights_ might be overstated by as much as the
|
||||
// smallest_weight. We therefore assume that each has been overstated
|
||||
// by smallest_weight / 2.0, and we re-distribute that mass over all
|
||||
// num_classes_ classes.
|
||||
float smallest_weight = 0.0;
|
||||
auto it = class_weights_.find(smallest_weight_class_);
|
||||
if (it != class_weights_.end()) {
|
||||
smallest_weight = it->second;
|
||||
}
|
||||
float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
|
||||
it = class_weights_.find(c);
|
||||
if (it != class_weights_.end()) {
|
||||
w += it->second - smallest_weight / 2.0;
|
||||
}
|
||||
return w;
|
||||
}
|
||||
|
||||
void FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const {
|
||||
*sum = 0.0;
|
||||
*square = 0.0;
|
||||
|
||||
float smallest_weight = 0.0;
|
||||
auto it = class_weights_.find(smallest_weight_class_);
|
||||
if (it != class_weights_.end()) {
|
||||
smallest_weight = it->second;
|
||||
}
|
||||
|
||||
float w;
|
||||
for (const auto it : class_weights_) {
|
||||
*sum += it.second;
|
||||
w = get_weight(it.first);
|
||||
*square += w * w;
|
||||
}
|
||||
|
||||
w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
|
||||
*square += (num_classes_ - n_) * w * w;
|
||||
}
|
||||
|
||||
// Restores class_weights_ from a serialized sparse vector of per-class
// weights.  smallest_weight_class_ is only set once the map is full
// (size == n_), matching the invariant maintained by accumulate().
void FixedSizeClassStats::ExtractFromProto(
    const decision_trees::SparseVector& sparse_vector) {
  for (const auto& it : sparse_vector.sparse_value()) {
    class_weights_[it.first] = it.second.float_value();
  }
  if (class_weights_.size() == n_) {
    smallest_weight_class_ = argmin(class_weights_);
  }
}
|
||||
|
||||
// Serializes class_weights_ into the given sparse vector proto,
// keyed by class id.
void FixedSizeClassStats::PackToProto(
    decision_trees::SparseVector* sparse_vector) const {
  // Iterate by const reference to avoid copying each map entry.
  for (const auto& it : class_weights_) {
    (*sparse_vector->mutable_sparse_value())[it.first].set_float_value(
        it.second);
  }
}
|
||||
|
||||
// --------------------- FixedSizeSparseClassificationGrowStats ------------- //
|
||||
|
||||
// Rebuilds this GrowStats from a serialized FertileSlot: restores the
// accumulated weight and, for each split candidate, the split itself and
// the fixed-size left/right class-weight trackers.
void FixedSizeSparseClassificationGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();

  // Candidate counts and splits.
  int split_num = 0;
  left_counts_.clear();
  right_counts_.clear();
  for (const auto& cand : slot.candidates()) {
    // NOTE(review): assumes AddSplit with a null example (-1) does not
    // itself populate left_counts_/right_counts_, since the counts are
    // restored explicitly below — confirm against AddSplit's definition.
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().sparse_counts();
    // Each tracker is sized by num_classes_to_track (n) over num_outputs
    // total classes, mirroring ClassificationAddSplitStats().
    left_counts_.emplace_back(params_.num_classes_to_track(),
                              params_.num_outputs());
    left_counts_[split_num].ExtractFromProto(left_stats);
    const auto& right_stats =
        cand.right_stats().classification().sparse_counts();
    right_counts_.emplace_back(params_.num_classes_to_track(),
                               params_.num_outputs());
    right_counts_[split_num].ExtractFromProto(right_stats);
    ++split_num;
  }
}
|
||||
|
||||
// Serializes this GrowStats into a FertileSlot: the accumulated weight
// plus, for each split candidate, the split and its left/right
// class-weight trackers.  Inverse of ExtractFromProto().
void FixedSizeSparseClassificationGrowStats::PackToProto(
    FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts();
    left_counts_[split_num].PackToProto(left_stats);
    auto* right_stats = cand->mutable_right_stats()
                            ->mutable_classification()
                            ->mutable_sparse_counts();
    right_counts_[split_num].PackToProto(right_stats);
  }
}
|
||||
|
||||
float FixedSizeSparseClassificationGrowStats::GiniScore(
|
||||
int split, float* left_sum, float* right_sum) const {
|
||||
float left_square, right_square;
|
||||
left_counts_[split].set_sum_and_square(left_sum, &left_square);
|
||||
right_counts_[split].set_sum_and_square(right_sum, &right_square);
|
||||
const int32 num_classes = params_.num_outputs();
|
||||
const float left_score =
|
||||
WeightedSmoothedGini(*left_sum, left_square, num_classes);
|
||||
const float right_score =
|
||||
WeightedSmoothedGini(*right_sum, right_square, num_classes);
|
||||
return left_score + right_score;
|
||||
}
|
||||
|
||||
void FixedSizeSparseClassificationGrowStats::InitLeafClassStats(
|
||||
int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
|
||||
auto* left_class_stats = left_stats->mutable_classification();
|
||||
auto* left_counts = left_class_stats->mutable_sparse_counts();
|
||||
left_counts_[best_split_index].PackToProto(left_counts);
|
||||
|
||||
auto* right_class_stats = right_stats->mutable_classification();
|
||||
auto* right_counts = right_class_stats->mutable_sparse_counts();
|
||||
right_counts_[best_split_index].PackToProto(right_counts);
|
||||
}
|
||||
|
||||
// --------------------- Least Squares Regression --------------------------- //
|
||||
|
@ -189,15 +189,29 @@ class ClassificationStats : public GrowStats {
|
||||
half_initialized_splits_.empty());
|
||||
}
|
||||
|
||||
bool BestSplit(SplitCandidate* best) const override;
|
||||
// When best_split_index has been chosen as the best split,
|
||||
// InitLeafClassStats is used to initialize the LeafStat's of the two
|
||||
// new leaves.
|
||||
virtual void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
|
||||
LeafStat* right_stats) const = 0;
|
||||
|
||||
protected:
|
||||
virtual float GiniScore(int split, float* left_sum,
|
||||
float* right_sum) const = 0;
|
||||
virtual int num_outputs_seen() const = 0;
|
||||
|
||||
// is_pure should return true if at most one class label has been seen
|
||||
// at the node, and false if two or more have been seen.
|
||||
virtual bool is_pure() const = 0;
|
||||
virtual float left_count(int split, int class_num) const = 0;
|
||||
virtual float right_count(int split, int class_num) const = 0;
|
||||
|
||||
virtual void ClassificationAddLeftExample(
|
||||
int split, int64 int_label, float weight) = 0;
|
||||
virtual void ClassificationAddRightExample(int split, int64 int_label,
|
||||
float weight) {
|
||||
// Does nothing by default, but sub-classes can override.
|
||||
}
|
||||
virtual void ClassificationAddTotalExample(int64 int_label, float weight) = 0;
|
||||
|
||||
virtual void ClassificationAddSplitStats() = 0;
|
||||
@ -301,7 +315,8 @@ class DenseClassificationGrowStats : public ClassificationStats {
|
||||
void ExtractFromProto(const FertileSlot& slot) override;
|
||||
void PackToProto(FertileSlot* slot) const override;
|
||||
|
||||
bool BestSplit(SplitCandidate* best) const override;
|
||||
void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
|
||||
LeafStat* right_stats) const;
|
||||
|
||||
protected:
|
||||
void ClassificationAddSplitStats() override {
|
||||
@ -317,9 +332,7 @@ class DenseClassificationGrowStats : public ClassificationStats {
|
||||
num_outputs_seen_ = 0;
|
||||
}
|
||||
|
||||
int num_outputs_seen() const override {
|
||||
return num_outputs_seen_;
|
||||
}
|
||||
bool is_pure() const override { return num_outputs_seen_ <= 1; }
|
||||
|
||||
void ClassificationAddLeftExample(int split, int64 int_label,
|
||||
float weight) override {
|
||||
@ -369,7 +382,8 @@ class SparseClassificationGrowStats : public ClassificationStats {
|
||||
void ExtractFromProto(const FertileSlot& slot) override;
|
||||
void PackToProto(FertileSlot* slot) const override;
|
||||
|
||||
bool BestSplit(SplitCandidate* best) const override;
|
||||
void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
|
||||
LeafStat* right_stats) const;
|
||||
|
||||
protected:
|
||||
void ClassificationAddSplitStats() override {
|
||||
@ -384,7 +398,7 @@ class SparseClassificationGrowStats : public ClassificationStats {
|
||||
left_counts_.clear();
|
||||
}
|
||||
|
||||
int num_outputs_seen() const override { return total_counts_.size(); }
|
||||
bool is_pure() const override { return total_counts_.size() <= 1; }
|
||||
|
||||
void ClassificationAddLeftExample(int split, int64 int_label,
|
||||
float weight) override {
|
||||
@ -412,6 +426,111 @@ class SparseClassificationGrowStats : public ClassificationStats {
|
||||
std::vector<std::unordered_map<int, float>> left_counts_;
|
||||
};
|
||||
|
||||
// Accumulates weights for the most popular classes while only using a
// fixed amount of space.
//
// Implements the "SpaceSaving" sketch (Metwally, Agrawal & El Abbadi):
// only the n most heavily weighted classes are tracked, so reported
// weights are approximate once more than n distinct classes are seen.
class FixedSizeClassStats {
 public:
  // n specifies how many classes are tracked.
  // num_classes is the total number of possible classes; it is used to
  // re-distribute the approximation error in get_weight().
  FixedSizeClassStats(int n, int num_classes)
      : n_(n), num_classes_(num_classes), smallest_weight_class_(-1) {}

  // Add weight w to the class c.
  void accumulate(int c, float w);

  // Return the approximate accumulated weight for class c. If c isn't one
  // of the n-most popular classes, this can be 0 even if c has accumulated
  // some weight.
  float get_weight(int c) const;

  // Put the sum of all weights seen into *sum, and
  //   \sum_c get_weight(c)^2
  // into *square. *sum will be exact, but *square will be approximate.
  void set_sum_and_square(float* sum, float* square) const;

  // Serialization to/from a sparse vector of per-class weights.
  void ExtractFromProto(const decision_trees::SparseVector& sparse_vector);
  void PackToProto(decision_trees::SparseVector* sparse_vector) const;

 private:
  // For our typical use cases, n_ is between 10 and 100, so there's no
  // need to track the smallest weight with a min_heap or the like.
  int n_;
  int num_classes_;

  // This tracks the class of the smallest weight, but isn't set until
  // class_weights_.size() == n_.
  int smallest_weight_class_;

  // Maps class id -> accumulated (approximate) weight; holds at most n_
  // entries.
  std::unordered_map<int, float> class_weights_;
};
|
||||
|
||||
// Tracks classification stats sparsely in a fixed amount of space.
|
||||
class FixedSizeSparseClassificationGrowStats : public ClassificationStats {
|
||||
public:
|
||||
FixedSizeSparseClassificationGrowStats(const TensorForestParams& params,
|
||||
int32 depth)
|
||||
: ClassificationStats(params, depth) {}
|
||||
|
||||
void Initialize() override { Clear(); }
|
||||
|
||||
void ExtractFromProto(const FertileSlot& slot) override;
|
||||
void PackToProto(FertileSlot* slot) const override;
|
||||
|
||||
void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
|
||||
LeafStat* right_stats) const;
|
||||
|
||||
protected:
|
||||
void ClassificationAddSplitStats() override {
|
||||
FixedSizeClassStats stats(params_.num_classes_to_track(),
|
||||
params_.num_outputs());
|
||||
left_counts_.resize(num_splits(), stats);
|
||||
right_counts_.resize(num_splits(), stats);
|
||||
}
|
||||
void ClassificationRemoveSplitStats(int split_num) override {
|
||||
left_counts_.erase(left_counts_.begin() + split_num,
|
||||
left_counts_.begin() + (split_num + 1));
|
||||
right_counts_.erase(right_counts_.begin() + split_num,
|
||||
right_counts_.begin() + (split_num + 1));
|
||||
}
|
||||
void ClearInternal() override {
|
||||
left_counts_.clear();
|
||||
right_counts_.clear();
|
||||
}
|
||||
|
||||
bool is_pure() const override { return first_two_classes_seen_.size() <= 1; }
|
||||
|
||||
void ClassificationAddLeftExample(int split, int64 int_label,
|
||||
float weight) override {
|
||||
left_counts_[split].accumulate(int_label, weight);
|
||||
}
|
||||
void ClassificationAddRightExample(int split, int64 int_label,
|
||||
float weight) override {
|
||||
right_counts_[split].accumulate(int_label, weight);
|
||||
}
|
||||
void ClassificationAddTotalExample(int64 int_label, float weight) override {
|
||||
if (is_pure()) {
|
||||
first_two_classes_seen_.insert(int_label);
|
||||
}
|
||||
}
|
||||
|
||||
float GiniScore(int split, float* left_sum, float* right_sum) const override;
|
||||
|
||||
float left_count(int split, int class_num) const override {
|
||||
return left_counts_[split].get_weight(class_num);
|
||||
}
|
||||
|
||||
float right_count(int split, int class_num) const override {
|
||||
return right_counts_[split].get_weight(class_num);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<FixedSizeClassStats> left_counts_;
|
||||
std::vector<FixedSizeClassStats> right_counts_;
|
||||
|
||||
// We keep track of the first two class labels seen, so we can tell if
|
||||
// the node is pure (= all of one class) or not.
|
||||
std::set<int> first_two_classes_seen_;
|
||||
};
|
||||
|
||||
// Tracks regression stats using least-squares minimization.
|
||||
class LeastSquaresRegressionGrowStats : public GrowStats {
|
||||
public:
|
||||
|
@ -29,6 +29,8 @@ using tensorflow::tensorforest::TestableInputTarget;
|
||||
using tensorflow::tensorforest::FertileSlot;
|
||||
using tensorflow::tensorforest::DenseClassificationGrowStats;
|
||||
using tensorflow::tensorforest::SparseClassificationGrowStats;
|
||||
using tensorflow::tensorforest::FixedSizeClassStats;
|
||||
using tensorflow::tensorforest::FixedSizeSparseClassificationGrowStats;
|
||||
using tensorflow::tensorforest::LeastSquaresRegressionGrowStats;
|
||||
using tensorflow::tensorforest::TensorForestParams;
|
||||
using tensorflow::tensorforest::SPLIT_FINISH_BASIC;
|
||||
@ -327,7 +329,6 @@ TEST(GrowStatsLeastSquaresRegressionTest, Basic) {
|
||||
ASSERT_EQ(serialized_again, serialized);
|
||||
}
|
||||
|
||||
|
||||
TEST(GrowStatsSparseClassificationTest, Basic) {
|
||||
TensorForestParams params;
|
||||
params.set_num_outputs(2);
|
||||
@ -360,5 +361,74 @@ TEST(GrowStatsSparseClassificationTest, Basic) {
|
||||
ASSERT_EQ(serialized_again, serialized);
|
||||
}
|
||||
|
||||
// With n (10) larger than the number of distinct classes seen (3), the
// SpaceSaving sketch never evicts an entry, so every stat is exact.
TEST(FixedSizeClassStats, Exact) {
  FixedSizeClassStats stats(10, 100);

  stats.accumulate(1, 1.0);
  stats.accumulate(2, 2.0);
  stats.accumulate(3, 3.0);

  EXPECT_EQ(stats.get_weight(1), 1.0);
  EXPECT_EQ(stats.get_weight(2), 2.0);
  EXPECT_EQ(stats.get_weight(3), 3.0);

  float sum;
  float square;
  stats.set_sum_and_square(&sum, &square);

  // sum = 1 + 2 + 3; square = 1^2 + 2^2 + 3^2.
  EXPECT_EQ(sum, 6.0);
  EXPECT_EQ(square, 14.0);
}
|
||||
|
||||
// With only n=5 slots for 10 distinct classes, the sketch is approximate;
// verify every reported weight stays within the algorithm's error bound.
TEST(FixedSizeClassStats, Approximate) {
  FixedSizeClassStats stats(5, 10);

  for (int i = 1; i <= 10; i++) {
    stats.accumulate(i, i * 1.0);
  }

  // We should be off by no more than *half* of the least weight
  // in the class_weights_, which is 7.
  float tolerance = 3.5;
  for (int i = 1; i <= 10; i++) {
    float diff = stats.get_weight(i) - i * 1.0;
    EXPECT_LE(diff, tolerance);
    EXPECT_GE(diff, -tolerance);
  }
}
|
||||
|
||||
// Round-trips a FixedSizeSparseClassificationGrowStats through its
// FertileSlot proto serialization and checks the result is identical.
TEST(GrowStatsFixedSizeSparseClassificationTest, Basic) {
  TensorForestParams params;
  params.set_num_outputs(2);
  params.set_num_classes_to_track(5);
  params.mutable_split_after_samples()->set_constant_value(2);
  params.mutable_num_splits_to_consider()->set_constant_value(2);
  std::unique_ptr<FixedSizeSparseClassificationGrowStats> stat(
      new FixedSizeSparseClassificationGrowStats(params, 1));
  stat->Initialize();

  std::vector<float> labels = {100, 1000, 1};
  std::vector<float> weights = {2.3, 20.3, 1.1};
  std::unique_ptr<TestableInputTarget> target(
      new TestableInputTarget(labels, weights, 1));
  // (An unused `branches` vector was removed here.)

  RunBatch(stat.get(), target.get());
  CHECK(stat->IsFinished());

  // Serialize, restore into a fresh object, re-serialize, and compare.
  FertileSlot slot;
  stat->PackToProto(&slot);

  string serialized = slot.DebugString();

  std::unique_ptr<FixedSizeSparseClassificationGrowStats> new_stat(
      new FixedSizeSparseClassificationGrowStats(params, 1));
  new_stat->ExtractFromProto(slot);
  FertileSlot second_one;
  new_stat->PackToProto(&second_one);
  string serialized_again = second_one.DebugString();
  ASSERT_EQ(serialized_again, serialized);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -55,6 +55,10 @@ std::unique_ptr<GrowStats> SplitCollectionOperator::CreateGrowStats(
|
||||
return std::unique_ptr<GrowStats>(new LeastSquaresRegressionGrowStats(
|
||||
params_, depth));
|
||||
|
||||
case STATS_FIXED_SIZE_SPARSE_GINI:
|
||||
return std::unique_ptr<GrowStats>(
|
||||
new FixedSizeSparseClassificationGrowStats(params_, depth));
|
||||
|
||||
default:
|
||||
LOG(ERROR) << "Unknown grow stats type: " << params_.stats_type();
|
||||
return nullptr;
|
||||
|
@ -20,7 +20,9 @@ enum StatsModelType {
|
||||
STATS_DENSE_GINI = 0;
|
||||
STATS_SPARSE_GINI = 1;
|
||||
STATS_LEAST_SQUARES_REGRESSION = 2;
|
||||
// STATS_SPARSE_THEN_DENSE_GINI is deprecated and no longer supported.
|
||||
STATS_SPARSE_THEN_DENSE_GINI = 3;
|
||||
STATS_FIXED_SIZE_SPARSE_GINI = 4;
|
||||
}
|
||||
|
||||
// Allows selection of operations on the collection of split candidates.
|
||||
@ -145,4 +147,8 @@ message TensorForestParams {
|
||||
// --------- Parameters for experimental features ---------------------- //
|
||||
string graph_dir = 16;
|
||||
int32 num_select_features = 17;
|
||||
|
||||
// When using a FixedSizeSparseClassificationGrowStats, keep track of
|
||||
// this many classes.
|
||||
int32 num_classes_to_track = 24;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user