Add option to TensorForest to initialize splits with the average of 2 examples.
PiperOrigin-RevId: 161122345
This commit is contained in:
parent
87d86dbbf4
commit
75f03e2d50
@ -36,7 +36,7 @@ void FertileStatsResource::AddExampleToStatsAndInitialize(
|
||||
// the top but gradually becomes less of an issue as the tree grows.
|
||||
for (int example : examples) {
|
||||
collection_op_->CreateAndInitializeCandidateWithExample(
|
||||
input_data, example, node_id);
|
||||
input_data, target, example, node_id);
|
||||
if (collection_op_->IsInitialized(node_id)) {
|
||||
break;
|
||||
}
|
||||
|
@ -96,8 +96,8 @@ void GraphRunnerSplitCollectionOperator::AddExample(
|
||||
|
||||
void GraphRunnerSplitCollectionOperator::
|
||||
CreateAndInitializeCandidateWithExample(
|
||||
const std::unique_ptr<TensorDataSet>& input_data, int example,
|
||||
int32 node_id) const {
|
||||
const std::unique_ptr<TensorDataSet>& input_data,
|
||||
const InputTarget* target, int example, int32 node_id) const {
|
||||
auto* slot = stats_.at(node_id).get();
|
||||
int cand_num = slot->num_splits();
|
||||
const int64 unique_id = UniqueId(node_id, cand_num);
|
||||
@ -125,7 +125,7 @@ void GraphRunnerSplitCollectionOperator::
|
||||
}
|
||||
}
|
||||
|
||||
slot->AddSplit(split);
|
||||
slot->AddSplit(split, input_data, target, example);
|
||||
|
||||
runners_[unique_id].reset(new CandidateGraphRunner(graph_dir_, split));
|
||||
runners_[unique_id]->Init();
|
||||
|
@ -56,8 +56,8 @@ class GraphRunnerSplitCollectionOperator : public SplitCollectionOperator {
|
||||
|
||||
// Create a new candidate and initialize it with the given example.
|
||||
void CreateAndInitializeCandidateWithExample(
|
||||
const std::unique_ptr<TensorDataSet>& input_data, int example,
|
||||
int32 node_id) const override;
|
||||
const std::unique_ptr<TensorDataSet>& input_data,
|
||||
const InputTarget* target, int example, int32 node_id) const override;
|
||||
|
||||
bool BestSplit(int32 node_id, SplitCandidate* best,
|
||||
int32* depth) const override;
|
||||
|
@ -38,11 +38,23 @@ GrowStats::GrowStats(const TensorForestParams& params, int32 depth)
|
||||
ResolveParam(params.num_splits_to_consider(), depth)),
|
||||
num_outputs_(params.num_outputs()) {}
|
||||
|
||||
void GrowStats::AddSplit(const decision_trees::BinaryNode& split) {
|
||||
splits_.push_back(split);
|
||||
evaluators_.emplace_back(
|
||||
CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
|
||||
AddSplitStats();
|
||||
// Registers `split` as a candidate for this slot and, when averaged split
// initialization is enabled, offers `example` to the subclass as an
// additional initialization example.
//
// split:      candidate split to add.
// input_data: features for `example`; may be nullptr, in which case the
//             additional-initialization step is skipped.
// target:     labels for `example`; may be nullptr, in which case the
//             additional-initialization step is skipped.
// example:    index of the example that produced `split`.
void GrowStats::AddSplit(const decision_trees::BinaryNode& split,
                         const std::unique_ptr<TensorDataSet>& input_data,
                         const InputTarget* target, int example) {
  // The split collection may keep calling AddSplit after every candidate
  // position is filled, while the existing candidates are still waiting to
  // be fully initialized. Only store the split when there is room for it.
  const bool has_room = splits_.size() < num_splits_to_consider_;
  if (has_room) {
    splits_.push_back(split);
    evaluators_.emplace_back(
        CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
    AddSplitStats(target, example);
  }

  // Without both the features and the labels there is no example to hand to
  // the subclass hook.
  const bool have_example = input_data != nullptr && target != nullptr;
  if (!have_example || !params_.initialize_average_splits()) {
    return;
  }
  AdditionalInitializationExample(input_data, target, example);
}
|
||||
|
||||
void GrowStats::RemoveSplit(int split_num) {
|
||||
@ -118,6 +130,34 @@ ClassificationStats::ClassificationStats(const TensorForestParams& params,
|
||||
new random::SimplePhilox(single_rand_.get()));
|
||||
}
|
||||
|
||||
void ClassificationStats::AdditionalInitializationExample(
|
||||
const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
|
||||
int example) {
|
||||
const int32 new_target = target->GetTargetAsClassIndex(example, 0);
|
||||
std::unordered_set<int> to_erase;
|
||||
for (auto it = half_initialized_splits_.begin();
|
||||
it != half_initialized_splits_.end(); ++it) {
|
||||
if (it->second != new_target) {
|
||||
auto& split = splits_[it->first];
|
||||
if (split.has_inequality_left_child_test()) {
|
||||
auto& test = split.inequality_left_child_test();
|
||||
auto* thresh =
|
||||
split.mutable_inequality_left_child_test()->mutable_threshold();
|
||||
if (test.has_feature_id()) {
|
||||
const float val =
|
||||
input_data->GetExampleValue(example, test.feature_id());
|
||||
thresh->set_float_value((thresh->float_value() + val) / 2);
|
||||
}
|
||||
}
|
||||
to_erase.insert(it->first);
|
||||
}
|
||||
}
|
||||
|
||||
for (const int split_id : to_erase) {
|
||||
half_initialized_splits_.erase(split_id);
|
||||
}
|
||||
}
|
||||
|
||||
bool ClassificationStats::IsFinished() const {
|
||||
bool basic = weight_sum_ >= split_after_samples_ && num_outputs_seen() > 1;
|
||||
return basic || finish_early_;
|
||||
@ -353,7 +393,7 @@ void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
|
||||
// Candidate counts and splits.
|
||||
int split_num = 0;
|
||||
for (const auto& cand : slot.candidates()) {
|
||||
AddSplit(cand.split());
|
||||
AddSplit(cand.split(), nullptr, nullptr, -1);
|
||||
const auto& left_stats = cand.left_stats().classification().dense_counts();
|
||||
for (int i = 0; i < num_classes; ++i) {
|
||||
const float val = left_stats.value(i).float_value();
|
||||
@ -474,7 +514,7 @@ void SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
|
||||
// Candidate counts and splits.
|
||||
int split_num = 0;
|
||||
for (const auto& cand : slot.candidates()) {
|
||||
AddSplit(cand.split());
|
||||
AddSplit(cand.split(), nullptr, nullptr, -1);
|
||||
const auto& left_stats = cand.left_stats().classification().sparse_counts();
|
||||
for (auto const& entry : left_stats.sparse_value()) {
|
||||
const float val = entry.second.float_value();
|
||||
@ -622,7 +662,7 @@ void LeastSquaresRegressionGrowStats::ExtractFromProto(
|
||||
// Candidate counts and splits.
|
||||
int split_num = 0;
|
||||
for (const auto& cand : slot.candidates()) {
|
||||
AddSplit(cand.split());
|
||||
AddSplit(cand.split(), nullptr, nullptr, -1);
|
||||
const auto& sums = cand.left_stats().regression().mean_output();
|
||||
const auto& squares = cand.left_stats().regression().mean_output_squares();
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
|
@ -65,7 +65,12 @@ class GrowStats {
|
||||
virtual void PackToProto(FertileSlot* slot) const = 0;
|
||||
|
||||
// Add split to the list of candidate splits.
|
||||
void AddSplit(const decision_trees::BinaryNode& split);
|
||||
void AddSplit(const decision_trees::BinaryNode& split,
|
||||
const std::unique_ptr<TensorDataSet>& input_data,
|
||||
const InputTarget* target, int example);
|
||||
// Hook called by AddSplit with the example that accompanied the new split,
// but only when both input_data and target are non-null and
// params_.initialize_average_splits() is set. The default implementation
// does nothing; subclasses may override it to refine candidate splits.
virtual void AdditionalInitializationExample(
|
||||
const std::unique_ptr<TensorDataSet>& input_data,
|
||||
const InputTarget* target, int example) {}
|
||||
void RemoveSplit(int split_num);
|
||||
|
||||
int num_splits() const {
|
||||
@ -76,7 +81,7 @@ class GrowStats {
|
||||
return weight_sum_;
|
||||
}
|
||||
|
||||
bool IsInitialized() const {
|
||||
virtual bool IsInitialized() const {
|
||||
return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_;
|
||||
}
|
||||
|
||||
@ -88,7 +93,7 @@ class GrowStats {
|
||||
GrowStats(const TensorForestParams& params, int32 depth);
|
||||
|
||||
// Function called by AddSplit for subclasses to initialize stats for a split.
|
||||
virtual void AddSplitStats() = 0;
|
||||
virtual void AddSplitStats(const InputTarget* target, int example) = 0;
|
||||
|
||||
virtual void RemoveSplitStats(int split_num) = 0;
|
||||
|
||||
@ -134,7 +139,7 @@ class SimpleStats : public GrowStats {
|
||||
}
|
||||
|
||||
protected:
|
||||
void AddSplitStats() override {}
|
||||
void AddSplitStats(const InputTarget* target, int example) override {}
|
||||
void RemoveSplitStats(int split_num) override {}
|
||||
void ClearInternal() override {}
|
||||
};
|
||||
@ -175,6 +180,15 @@ class ClassificationStats : public GrowStats {
|
||||
void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
|
||||
const InputTarget* target, int example) override;
|
||||
|
||||
void AdditionalInitializationExample(
|
||||
const std::unique_ptr<TensorDataSet>& input_data,
|
||||
const InputTarget* target, int example) override;
|
||||
|
||||
bool IsInitialized() const override {
|
||||
return weight_sum_ > 0 || (splits_.size() == num_splits_to_consider_ &&
|
||||
half_initialized_splits_.empty());
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual float GiniScore(int split, float* left_sum,
|
||||
float* right_sum) const = 0;
|
||||
@ -189,11 +203,17 @@ class ClassificationStats : public GrowStats {
|
||||
virtual void ClassificationAddSplitStats() = 0;
|
||||
virtual void ClassificationRemoveSplitStats(int split) = 0;
|
||||
|
||||
void AddSplitStats() override {
|
||||
void AddSplitStats(const InputTarget* target, int example) override {
|
||||
if (left_gini_ != nullptr) {
|
||||
left_gini_->add_split();
|
||||
right_gini_->add_split();
|
||||
}
|
||||
if (params_.initialize_average_splits()) {
|
||||
if (splits_[splits_.size() - 1].has_inequality_left_child_test()) {
|
||||
half_initialized_splits_[splits_.size() - 1] =
|
||||
target->GetTargetAsClassIndex(example, 0);
|
||||
}
|
||||
}
|
||||
ClassificationAddSplitStats();
|
||||
}
|
||||
void RemoveSplitStats(int split) override {
|
||||
@ -262,6 +282,9 @@ class ClassificationStats : public GrowStats {
|
||||
|
||||
std::unique_ptr<RunningGiniScores> left_gini_;
|
||||
std::unique_ptr<RunningGiniScores> right_gini_;
|
||||
|
||||
// Stores split number -> class that was first seen.
|
||||
std::unordered_map<int, int32> half_initialized_splits_;
|
||||
};
|
||||
|
||||
// Tracks classification stats by storing class counts densely.
|
||||
@ -413,7 +436,7 @@ class LeastSquaresRegressionGrowStats : public GrowStats {
|
||||
// Returns the variance of split.
|
||||
float SplitVariance(int split) const;
|
||||
|
||||
void AddSplitStats() override {
|
||||
void AddSplitStats(const InputTarget* target, int example) override {
|
||||
left_sums_.resize(num_outputs_ * num_splits());
|
||||
left_squares_.resize(num_outputs_ * num_splits());
|
||||
left_counts_.push_back(0);
|
||||
|
@ -52,13 +52,13 @@ BinaryNode MakeSplit(const string& feat, float val) {
|
||||
|
||||
void RunBatch(GrowStats* stats,
|
||||
const TestableInputTarget* target) {
|
||||
stats->AddSplit(MakeSplit("0", 10.0));
|
||||
stats->AddSplit(MakeSplit("1", 4.0));
|
||||
|
||||
std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
|
||||
new tensorflow::tensorforest::TestableDataSet(
|
||||
{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, 2));
|
||||
|
||||
stats->AddSplit(MakeSplit("0", 10.0), dataset, target, 0);
|
||||
stats->AddSplit(MakeSplit("1", 4.0), dataset, target, 0);
|
||||
|
||||
for (int i = 0; i < target->NumItems(); ++i) {
|
||||
stats->AddExample(dataset, target, i);
|
||||
}
|
||||
@ -225,11 +225,6 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) {
|
||||
params.mutable_pruning_type()->mutable_prune_every_samples()
|
||||
->set_constant_value(1);
|
||||
|
||||
DenseClassificationGrowStats stats(params, 1);
|
||||
stats.Initialize();
|
||||
stats.AddSplit(MakeSplit("0", 0.0));
|
||||
stats.AddSplit(MakeSplit("1", 0.0));
|
||||
|
||||
// On each iteration, we add two examples, one of class 0 and one
|
||||
// of class 1. Split #0 classifies them perfectly, while split #1
|
||||
// sends them both to the left.
|
||||
@ -240,6 +235,11 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) {
|
||||
new tensorflow::tensorforest::TestableDataSet(
|
||||
{-1.0, -1.0, 1.0, -1.0}, 2));
|
||||
|
||||
DenseClassificationGrowStats stats(params, 1);
|
||||
stats.Initialize();
|
||||
stats.AddSplit(MakeSplit("0", 0.0), dataset, &target, 0);
|
||||
stats.AddSplit(MakeSplit("1", 0.0), dataset, &target, 0);
|
||||
|
||||
// Math time!
|
||||
// After 2n samples,
|
||||
// split 0 has smoothed counts (n+1,1);(1,n+1) and
|
||||
|
@ -100,8 +100,8 @@ bool SplitCollectionOperator::IsInitialized(int32 node_id) const {
|
||||
}
|
||||
|
||||
void SplitCollectionOperator::CreateAndInitializeCandidateWithExample(
|
||||
const std::unique_ptr<TensorDataSet>& input_data, int example,
|
||||
int32 node_id) const {
|
||||
const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
|
||||
int example, int32 node_id) const {
|
||||
// Assumes split_initializations_per_input == 1.
|
||||
decision_trees::BinaryNode split;
|
||||
float bias;
|
||||
@ -124,7 +124,7 @@ void SplitCollectionOperator::CreateAndInitializeCandidateWithExample(
|
||||
LOG(ERROR) << "Unknown feature type " << type << ", not sure which "
|
||||
<< "node type to use.";
|
||||
}
|
||||
stats_.at(node_id)->AddSplit(split);
|
||||
stats_.at(node_id)->AddSplit(split, input_data, target, example);
|
||||
}
|
||||
|
||||
bool SplitCollectionOperator::BestSplit(int32 node_id,
|
||||
@ -134,6 +134,5 @@ bool SplitCollectionOperator::BestSplit(int32 node_id,
|
||||
*depth = slot->depth();
|
||||
return slot->BestSplit(best);
|
||||
}
|
||||
|
||||
} // namespace tensorforest
|
||||
} // namespace tensorflow
|
||||
|
@ -56,8 +56,8 @@ class SplitCollectionOperator {
|
||||
|
||||
// Create a new candidate and initialize it with the given example.
|
||||
virtual void CreateAndInitializeCandidateWithExample(
|
||||
const std::unique_ptr<TensorDataSet>& input_data, int example,
|
||||
int32 node_id) const;
|
||||
const std::unique_ptr<TensorDataSet>& input_data,
|
||||
const InputTarget* target, int example, int32 node_id) const;
|
||||
|
||||
// Create a new GrowStats for the given node id and initialize it.
|
||||
virtual void InitializeSlot(int32 node_id, int32 depth);
|
||||
|
@ -130,6 +130,7 @@ message TensorForestParams {
|
||||
bool collate_examples = 10;
|
||||
bool checkpoint_stats = 11;
|
||||
bool use_running_stats_method = 20;
|
||||
bool initialize_average_splits = 22;
|
||||
|
||||
// Number of classes (classification) or targets (regression)
|
||||
int32 num_outputs = 12;
|
||||
|
@ -90,6 +90,7 @@ def build_params_proto(params):
|
||||
proto.collate_examples = params.v4_collate_examples
|
||||
proto.checkpoint_stats = params.v4_checkpoint_stats
|
||||
proto.use_running_stats_method = params.v4_use_running_stats_method
|
||||
proto.initialize_average_splits = params.v4_initialize_average_splits
|
||||
|
||||
if params.v4_prune_every_samples:
|
||||
text_format.Merge(params.v4_prune_every_samples,
|
||||
@ -174,6 +175,8 @@ class V4ForestHParams(object):
|
||||
self.v4_checkpoint_stats = getattr(self, 'v4_checkpoint_stats', False)
|
||||
self.v4_use_running_stats_method = getattr(
|
||||
self, 'v4_use_running_stats_method', False)
|
||||
self.v4_initialize_average_splits = getattr(
|
||||
self, 'v4_initialize_average_splits', False)
|
||||
|
||||
self.v4_param_file = getattr(self, 'v4_param_file', None)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user