From 75f03e2d509d016021f8508555f9ab96af2c7cfe Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Jul 2017 13:29:22 -0700 Subject: [PATCH] Add option to TensorForest to initialize splits with the average of 2 examples. PiperOrigin-RevId: 161122345 --- .../kernels/v4/fertile-stats-resource.cc | 2 +- .../kernels/v4/graph_collection_operator.cc | 6 +- .../kernels/v4/graph_collection_operator.h | 4 +- .../tensor_forest/kernels/v4/grow_stats.cc | 56 ++++++++++++++++--- .../tensor_forest/kernels/v4/grow_stats.h | 35 ++++++++++-- .../kernels/v4/grow_stats_test.cc | 16 +++--- .../kernels/v4/split_collection_operators.cc | 7 +-- .../kernels/v4/split_collection_operators.h | 4 +- .../proto/tensor_forest_params.proto | 1 + .../tensor_forest/python/tensor_forest_v4.py | 3 + 10 files changed, 100 insertions(+), 34 deletions(-) diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc index 9f5d9485143..5c1b7454ae6 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc @@ -36,7 +36,7 @@ void FertileStatsResource::AddExampleToStatsAndInitialize( // the top but gradually becomes less of an issue as the tree grows. for (int example : examples) { collection_op_->CreateAndInitializeCandidateWithExample( - input_data, example, node_id); + input_data, target, example, node_id); if (collection_op_->IsInitialized(node_id)) { break; } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc index 2c925b5dd77..c7faea0aef1 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc @@ -96,8 +96,8 @@ void GraphRunnerSplitCollectionOperator::AddExample( void GraphRunnerSplitCollectionOperator:: CreateAndInitializeCandidateWithExample( - const std::unique_ptr& input_data, int example, - int32 node_id) const { + const std::unique_ptr& input_data, + const InputTarget* target, int example, int32 node_id) const { auto* slot = stats_.at(node_id).get(); int cand_num = slot->num_splits(); const int64 unique_id = UniqueId(node_id, cand_num); @@ -125,7 +125,7 @@ void GraphRunnerSplitCollectionOperator:: } } - slot->AddSplit(split); + slot->AddSplit(split, input_data, target, example); runners_[unique_id].reset(new CandidateGraphRunner(graph_dir_, split)); runners_[unique_id]->Init(); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h index 9b18e3e9694..2ae3a79b3dd 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h @@ -56,8 +56,8 @@ class GraphRunnerSplitCollectionOperator : public SplitCollectionOperator { // Create a new candidate and initialize it with the given example. void CreateAndInitializeCandidateWithExample( - const std::unique_ptr& input_data, int example, - int32 node_id) const override; + const std::unique_ptr& input_data, + const InputTarget* target, int example, int32 node_id) const override; bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth) const override; diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc index 226376c571b..81b4534f10e 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc @@ -38,11 +38,23 @@ GrowStats::GrowStats(const TensorForestParams& params, int32 depth) ResolveParam(params.num_splits_to_consider(), depth)), num_outputs_(params.num_outputs()) {} -void GrowStats::AddSplit(const decision_trees::BinaryNode& split) { - splits_.push_back(split); - evaluators_.emplace_back( - CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX)); - AddSplitStats(); +void GrowStats::AddSplit(const decision_trees::BinaryNode& split, + const std::unique_ptr& input_data, + const InputTarget* target, int example) { + // It's possible that the split collection calls AddSplit, but we actually + // have all the splits we need and are just waiting for them to be fully + // initialized. + if (splits_.size() < num_splits_to_consider_) { + splits_.push_back(split); + evaluators_.emplace_back( + CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX)); + AddSplitStats(target, example); + } + + if (input_data != nullptr && target != nullptr && + params_.initialize_average_splits()) { + AdditionalInitializationExample(input_data, target, example); + } } void GrowStats::RemoveSplit(int split_num) { @@ -118,6 +130,34 @@ ClassificationStats::ClassificationStats(const TensorForestParams& params, new random::SimplePhilox(single_rand_.get())); } +void ClassificationStats::AdditionalInitializationExample( + const std::unique_ptr& input_data, const InputTarget* target, + int example) { + const int32 new_target = target->GetTargetAsClassIndex(example, 0); + std::unordered_set to_erase; + for (auto it = half_initialized_splits_.begin(); + it != half_initialized_splits_.end(); ++it) { + if (it->second != new_target) { + auto& split = splits_[it->first]; + if (split.has_inequality_left_child_test()) { + auto& test = split.inequality_left_child_test(); + auto* thresh = + split.mutable_inequality_left_child_test()->mutable_threshold(); + if (test.has_feature_id()) { + const float val = + input_data->GetExampleValue(example, test.feature_id()); + thresh->set_float_value((thresh->float_value() + val) / 2); + } + } + to_erase.insert(it->first); + } + } + + for (const int split_id : to_erase) { + half_initialized_splits_.erase(split_id); + } +} + bool ClassificationStats::IsFinished() const { bool basic = weight_sum_ >= split_after_samples_ && num_outputs_seen() > 1; return basic || finish_early_; @@ -353,7 +393,7 @@ void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) { // Candidate counts and splits. int split_num = 0; for (const auto& cand : slot.candidates()) { - AddSplit(cand.split()); + AddSplit(cand.split(), nullptr, nullptr, -1); const auto& left_stats = cand.left_stats().classification().dense_counts(); for (int i = 0; i < num_classes; ++i) { const float val = left_stats.value(i).float_value(); @@ -474,7 +514,7 @@ void SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) { // Candidate counts and splits. int split_num = 0; for (const auto& cand : slot.candidates()) { - AddSplit(cand.split()); + AddSplit(cand.split(), nullptr, nullptr, -1); const auto& left_stats = cand.left_stats().classification().sparse_counts(); for (auto const& entry : left_stats.sparse_value()) { const float val = entry.second.float_value(); @@ -622,7 +662,7 @@ void LeastSquaresRegressionGrowStats::ExtractFromProto( // Candidate counts and splits. int split_num = 0; for (const auto& cand : slot.candidates()) { - AddSplit(cand.split()); + AddSplit(cand.split(), nullptr, nullptr, -1); const auto& sums = cand.left_stats().regression().mean_output(); const auto& squares = cand.left_stats().regression().mean_output_squares(); for (int i = 0; i < num_outputs; ++i) { diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h index 8d32b4961b1..6702b81c791 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h @@ -65,7 +65,12 @@ class GrowStats { virtual void PackToProto(FertileSlot* slot) const = 0; // Add split to the list of candidate splits. - void AddSplit(const decision_trees::BinaryNode& split); + void AddSplit(const decision_trees::BinaryNode& split, + const std::unique_ptr& input_data, + const InputTarget* target, int example); + virtual void AdditionalInitializationExample( + const std::unique_ptr& input_data, + const InputTarget* target, int example) {} void RemoveSplit(int split_num); int num_splits() const { @@ -76,7 +81,7 @@ class GrowStats { return weight_sum_; } - bool IsInitialized() const { + virtual bool IsInitialized() const { return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_; } @@ -88,7 +93,7 @@ class GrowStats { GrowStats(const TensorForestParams& params, int32 depth); // Function called by AddSplit for subclasses to initialize stats for a split. - virtual void AddSplitStats() = 0; + virtual void AddSplitStats(const InputTarget* target, int example) = 0; virtual void RemoveSplitStats(int split_num) = 0; @@ -134,7 +139,7 @@ class SimpleStats : public GrowStats { } protected: - void AddSplitStats() override {} + void AddSplitStats(const InputTarget* target, int example) override {} void RemoveSplitStats(int split_num) override {} void ClearInternal() override {} }; @@ -175,6 +180,15 @@ class ClassificationStats : public GrowStats { void AddExample(const std::unique_ptr& input_data, const InputTarget* target, int example) override; + void AdditionalInitializationExample( + const std::unique_ptr& input_data, + const InputTarget* target, int example) override; + + bool IsInitialized() const override { + return weight_sum_ > 0 || (splits_.size() == num_splits_to_consider_ && + half_initialized_splits_.empty()); + } + protected: virtual float GiniScore(int split, float* left_sum, float* right_sum) const = 0; @@ -189,11 +203,17 @@ class ClassificationStats : public GrowStats { virtual void ClassificationAddSplitStats() = 0; virtual void ClassificationRemoveSplitStats(int split) = 0; - void AddSplitStats() override { + void AddSplitStats(const InputTarget* target, int example) override { if (left_gini_ != nullptr) { left_gini_->add_split(); right_gini_->add_split(); } + if (params_.initialize_average_splits()) { + if (splits_[splits_.size() - 1].has_inequality_left_child_test()) { + half_initialized_splits_[splits_.size() - 1] = + target->GetTargetAsClassIndex(example, 0); + } + } ClassificationAddSplitStats(); } void RemoveSplitStats(int split) override { @@ -262,6 +282,9 @@ class ClassificationStats : public GrowStats { std::unique_ptr left_gini_; std::unique_ptr right_gini_; + + // Stores split number -> class that was first seen. + std::unordered_map half_initialized_splits_; }; // Tracks classification stats by storing class counts densely. @@ -413,7 +436,7 @@ class LeastSquaresRegressionGrowStats : public GrowStats { // Returns the variance of split. float SplitVariance(int split) const; - void AddSplitStats() override { + void AddSplitStats(const InputTarget* target, int example) override { left_sums_.resize(num_outputs_ * num_splits()); left_squares_.resize(num_outputs_ * num_splits()); left_counts_.push_back(0); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc index 7cd16d222ad..fa959e8373a 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc @@ -52,13 +52,13 @@ BinaryNode MakeSplit(const string& feat, float val) { void RunBatch(GrowStats* stats, const TestableInputTarget* target) { - stats->AddSplit(MakeSplit("0", 10.0)); - stats->AddSplit(MakeSplit("1", 4.0)); - std::unique_ptr dataset( new tensorflow::tensorforest::TestableDataSet( {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, 2)); + stats->AddSplit(MakeSplit("0", 10.0), dataset, target, 0); + stats->AddSplit(MakeSplit("1", 4.0), dataset, target, 0); + for (int i = 0; i < target->NumItems(); ++i) { stats->AddExample(dataset, target, i); } @@ -225,11 +225,6 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) { params.mutable_pruning_type()->mutable_prune_every_samples() ->set_constant_value(1); - DenseClassificationGrowStats stats(params, 1); - stats.Initialize(); - stats.AddSplit(MakeSplit("0", 0.0)); - stats.AddSplit(MakeSplit("1", 0.0)); - // On each iteration, we add two examples, one of class 0 and one // of class 1. Split #0 classifies them perfectly, while split #1 // sends them both to the left. @@ -240,6 +235,11 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) { new tensorflow::tensorforest::TestableDataSet( {-1.0, -1.0, 1.0, -1.0}, 2)); + DenseClassificationGrowStats stats(params, 1); + stats.Initialize(); + stats.AddSplit(MakeSplit("0", 0.0), dataset, &target, 0); + stats.AddSplit(MakeSplit("1", 0.0), dataset, &target, 0); + // Math time! // After 2n samples, // split 0 has smoothed counts (n+1,1);(1,n+1) and diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc index c207c0859d8..632408fd718 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc @@ -100,8 +100,8 @@ bool SplitCollectionOperator::IsInitialized(int32 node_id) const { } void SplitCollectionOperator::CreateAndInitializeCandidateWithExample( - const std::unique_ptr& input_data, int example, - int32 node_id) const { + const std::unique_ptr& input_data, const InputTarget* target, + int example, int32 node_id) const { // Assumes split_initializations_per_input == 1. decision_trees::BinaryNode split; float bias; @@ -124,7 +124,7 @@ void SplitCollectionOperator::CreateAndInitializeCandidateWithExample( LOG(ERROR) << "Unknown feature type " << type << ", not sure which " << "node type to use."; } - stats_.at(node_id)->AddSplit(split); + stats_.at(node_id)->AddSplit(split, input_data, target, example); } bool SplitCollectionOperator::BestSplit(int32 node_id, @@ -134,6 +134,5 @@ bool SplitCollectionOperator::BestSplit(int32 node_id, *depth = slot->depth(); return slot->BestSplit(best); } - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h index 81d820a6b28..6990e82678b 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h @@ -56,8 +56,8 @@ class SplitCollectionOperator { // Create a new candidate and initialize it with the given example. virtual void CreateAndInitializeCandidateWithExample( - const std::unique_ptr& input_data, int example, - int32 node_id) const; + const std::unique_ptr& input_data, + const InputTarget* target, int example, int32 node_id) const; // Create a new GrowStats for the given node id and initialize it. virtual void InitializeSlot(int32 node_id, int32 depth); diff --git a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto index 49b19e0b623..58c5b9bbe73 100644 --- a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto +++ b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto @@ -130,6 +130,7 @@ message TensorForestParams { bool collate_examples = 10; bool checkpoint_stats = 11; bool use_running_stats_method = 20; + bool initialize_average_splits = 22; // Number of classes (classification) or targets (regression) int32 num_outputs = 12; diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py index 5773b5d729a..7e6f00a13df 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py @@ -90,6 +90,7 @@ def build_params_proto(params): proto.collate_examples = params.v4_collate_examples proto.checkpoint_stats = params.v4_checkpoint_stats proto.use_running_stats_method = params.v4_use_running_stats_method + proto.initialize_average_splits = params.v4_initialize_average_splits if params.v4_prune_every_samples: text_format.Merge(params.v4_prune_every_samples, @@ -174,6 +175,8 @@ class V4ForestHParams(object): self.v4_checkpoint_stats = getattr(self, 'v4_checkpoint_stats', False) self.v4_use_running_stats_method = getattr( self, 'v4_use_running_stats_method', False) + self.v4_initialize_average_splits = getattr( + self, 'v4_initialize_average_splits', False) self.v4_param_file = getattr(self, 'v4_param_file', None)