Add option to TensorForest to initialize splits with the average of 2 examples.

PiperOrigin-RevId: 161122345
This commit is contained in:
A. Unique TensorFlower 2017-07-06 13:29:22 -07:00 committed by TensorFlower Gardener
parent 87d86dbbf4
commit 75f03e2d50
10 changed files with 100 additions and 34 deletions

View File

@@ -36,7 +36,7 @@ void FertileStatsResource::AddExampleToStatsAndInitialize(
// the top but gradually becomes less of an issue as the tree grows.
for (int example : examples) {
collection_op_->CreateAndInitializeCandidateWithExample(
input_data, example, node_id);
input_data, target, example, node_id);
if (collection_op_->IsInitialized(node_id)) {
break;
}

View File

@@ -96,8 +96,8 @@ void GraphRunnerSplitCollectionOperator::AddExample(
void GraphRunnerSplitCollectionOperator::
CreateAndInitializeCandidateWithExample(
const std::unique_ptr<TensorDataSet>& input_data, int example,
int32 node_id) const {
const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, int example, int32 node_id) const {
auto* slot = stats_.at(node_id).get();
int cand_num = slot->num_splits();
const int64 unique_id = UniqueId(node_id, cand_num);
@@ -125,7 +125,7 @@ void GraphRunnerSplitCollectionOperator::
}
}
slot->AddSplit(split);
slot->AddSplit(split, input_data, target, example);
runners_[unique_id].reset(new CandidateGraphRunner(graph_dir_, split));
runners_[unique_id]->Init();

View File

@@ -56,8 +56,8 @@ class GraphRunnerSplitCollectionOperator : public SplitCollectionOperator {
// Create a new candidate and initialize it with the given example.
void CreateAndInitializeCandidateWithExample(
const std::unique_ptr<TensorDataSet>& input_data, int example,
int32 node_id) const override;
const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, int example, int32 node_id) const override;
bool BestSplit(int32 node_id, SplitCandidate* best,
int32* depth) const override;

View File

@@ -38,11 +38,23 @@ GrowStats::GrowStats(const TensorForestParams& params, int32 depth)
ResolveParam(params.num_splits_to_consider(), depth)),
num_outputs_(params.num_outputs()) {}
void GrowStats::AddSplit(const decision_trees::BinaryNode& split) {
splits_.push_back(split);
evaluators_.emplace_back(
CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
AddSplitStats();
void GrowStats::AddSplit(const decision_trees::BinaryNode& split,
const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, int example) {
// It's possible that the split collection calls AddSplit, but we actually
// have all the splits we need and are just waiting for them to be fully
// initialized.
if (splits_.size() < num_splits_to_consider_) {
splits_.push_back(split);
evaluators_.emplace_back(
CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
AddSplitStats(target, example);
}
if (input_data != nullptr && target != nullptr &&
params_.initialize_average_splits()) {
AdditionalInitializationExample(input_data, target, example);
}
}
void GrowStats::RemoveSplit(int split_num) {
@@ -118,6 +130,34 @@ ClassificationStats::ClassificationStats(const TensorForestParams& params,
new random::SimplePhilox(single_rand_.get()));
}
void ClassificationStats::AdditionalInitializationExample(
const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
int example) {
const int32 new_target = target->GetTargetAsClassIndex(example, 0);
std::unordered_set<int> to_erase;
for (auto it = half_initialized_splits_.begin();
it != half_initialized_splits_.end(); ++it) {
if (it->second != new_target) {
auto& split = splits_[it->first];
if (split.has_inequality_left_child_test()) {
auto& test = split.inequality_left_child_test();
auto* thresh =
split.mutable_inequality_left_child_test()->mutable_threshold();
if (test.has_feature_id()) {
const float val =
input_data->GetExampleValue(example, test.feature_id());
thresh->set_float_value((thresh->float_value() + val) / 2);
}
}
to_erase.insert(it->first);
}
}
for (const int split_id : to_erase) {
half_initialized_splits_.erase(split_id);
}
}
bool ClassificationStats::IsFinished() const {
bool basic = weight_sum_ >= split_after_samples_ && num_outputs_seen() > 1;
return basic || finish_early_;
@@ -353,7 +393,7 @@ void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
// Candidate counts and splits.
int split_num = 0;
for (const auto& cand : slot.candidates()) {
AddSplit(cand.split());
AddSplit(cand.split(), nullptr, nullptr, -1);
const auto& left_stats = cand.left_stats().classification().dense_counts();
for (int i = 0; i < num_classes; ++i) {
const float val = left_stats.value(i).float_value();
@@ -474,7 +514,7 @@ void SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
// Candidate counts and splits.
int split_num = 0;
for (const auto& cand : slot.candidates()) {
AddSplit(cand.split());
AddSplit(cand.split(), nullptr, nullptr, -1);
const auto& left_stats = cand.left_stats().classification().sparse_counts();
for (auto const& entry : left_stats.sparse_value()) {
const float val = entry.second.float_value();
@@ -622,7 +662,7 @@ void LeastSquaresRegressionGrowStats::ExtractFromProto(
// Candidate counts and splits.
int split_num = 0;
for (const auto& cand : slot.candidates()) {
AddSplit(cand.split());
AddSplit(cand.split(), nullptr, nullptr, -1);
const auto& sums = cand.left_stats().regression().mean_output();
const auto& squares = cand.left_stats().regression().mean_output_squares();
for (int i = 0; i < num_outputs; ++i) {

View File

@@ -65,7 +65,12 @@ class GrowStats {
virtual void PackToProto(FertileSlot* slot) const = 0;
// Add split to the list of candidate splits.
void AddSplit(const decision_trees::BinaryNode& split);
void AddSplit(const decision_trees::BinaryNode& split,
const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, int example);
virtual void AdditionalInitializationExample(
const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, int example) {}
void RemoveSplit(int split_num);
int num_splits() const {
@@ -76,7 +81,7 @@ class GrowStats {
return weight_sum_;
}
bool IsInitialized() const {
virtual bool IsInitialized() const {
return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_;
}
@@ -88,7 +93,7 @@ class GrowStats {
GrowStats(const TensorForestParams& params, int32 depth);
// Function called by AddSplit for subclasses to initialize stats for a split.
virtual void AddSplitStats() = 0;
virtual void AddSplitStats(const InputTarget* target, int example) = 0;
virtual void RemoveSplitStats(int split_num) = 0;
@@ -134,7 +139,7 @@ class SimpleStats : public GrowStats {
}
protected:
void AddSplitStats() override {}
void AddSplitStats(const InputTarget* target, int example) override {}
void RemoveSplitStats(int split_num) override {}
void ClearInternal() override {}
};
@@ -175,6 +180,15 @@ class ClassificationStats : public GrowStats {
void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, int example) override;
void AdditionalInitializationExample(
const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, int example) override;
bool IsInitialized() const override {
return weight_sum_ > 0 || (splits_.size() == num_splits_to_consider_ &&
half_initialized_splits_.empty());
}
protected:
virtual float GiniScore(int split, float* left_sum,
float* right_sum) const = 0;
@@ -189,11 +203,17 @@ class ClassificationStats : public GrowStats {
virtual void ClassificationAddSplitStats() = 0;
virtual void ClassificationRemoveSplitStats(int split) = 0;
void AddSplitStats() override {
void AddSplitStats(const InputTarget* target, int example) override {
if (left_gini_ != nullptr) {
left_gini_->add_split();
right_gini_->add_split();
}
if (params_.initialize_average_splits()) {
if (splits_[splits_.size() - 1].has_inequality_left_child_test()) {
half_initialized_splits_[splits_.size() - 1] =
target->GetTargetAsClassIndex(example, 0);
}
}
ClassificationAddSplitStats();
}
void RemoveSplitStats(int split) override {
@@ -262,6 +282,9 @@ class ClassificationStats : public GrowStats {
std::unique_ptr<RunningGiniScores> left_gini_;
std::unique_ptr<RunningGiniScores> right_gini_;
// Stores split number -> class that was first seen.
std::unordered_map<int, int32> half_initialized_splits_;
};
// Tracks classification stats by storing class counts densely.
@@ -413,7 +436,7 @@ class LeastSquaresRegressionGrowStats : public GrowStats {
// Returns the variance of split.
float SplitVariance(int split) const;
void AddSplitStats() override {
void AddSplitStats(const InputTarget* target, int example) override {
left_sums_.resize(num_outputs_ * num_splits());
left_squares_.resize(num_outputs_ * num_splits());
left_counts_.push_back(0);

View File

@@ -52,13 +52,13 @@ BinaryNode MakeSplit(const string& feat, float val) {
void RunBatch(GrowStats* stats,
const TestableInputTarget* target) {
stats->AddSplit(MakeSplit("0", 10.0));
stats->AddSplit(MakeSplit("1", 4.0));
std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
new tensorflow::tensorforest::TestableDataSet(
{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, 2));
stats->AddSplit(MakeSplit("0", 10.0), dataset, target, 0);
stats->AddSplit(MakeSplit("1", 4.0), dataset, target, 0);
for (int i = 0; i < target->NumItems(); ++i) {
stats->AddExample(dataset, target, i);
}
@@ -225,11 +225,6 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) {
params.mutable_pruning_type()->mutable_prune_every_samples()
->set_constant_value(1);
DenseClassificationGrowStats stats(params, 1);
stats.Initialize();
stats.AddSplit(MakeSplit("0", 0.0));
stats.AddSplit(MakeSplit("1", 0.0));
// On each iteration, we add two examples, one of class 0 and one
// of class 1. Split #0 classifies them perfectly, while split #1
// sends them both to the left.
@@ -240,6 +235,11 @@ TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) {
new tensorflow::tensorforest::TestableDataSet(
{-1.0, -1.0, 1.0, -1.0}, 2));
DenseClassificationGrowStats stats(params, 1);
stats.Initialize();
stats.AddSplit(MakeSplit("0", 0.0), dataset, &target, 0);
stats.AddSplit(MakeSplit("1", 0.0), dataset, &target, 0);
// Math time!
// After 2n samples,
// split 0 has smoothed counts (n+1,1);(1,n+1) and

View File

@@ -100,8 +100,8 @@ bool SplitCollectionOperator::IsInitialized(int32 node_id) const {
}
void SplitCollectionOperator::CreateAndInitializeCandidateWithExample(
const std::unique_ptr<TensorDataSet>& input_data, int example,
int32 node_id) const {
const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
int example, int32 node_id) const {
// Assumes split_initializations_per_input == 1.
decision_trees::BinaryNode split;
float bias;
@@ -124,7 +124,7 @@ void SplitCollectionOperator::CreateAndInitializeCandidateWithExample(
LOG(ERROR) << "Unknown feature type " << type << ", not sure which "
<< "node type to use.";
}
stats_.at(node_id)->AddSplit(split);
stats_.at(node_id)->AddSplit(split, input_data, target, example);
}
bool SplitCollectionOperator::BestSplit(int32 node_id,
@@ -134,6 +134,5 @@ bool SplitCollectionOperator::BestSplit(int32 node_id,
*depth = slot->depth();
return slot->BestSplit(best);
}
} // namespace tensorforest
} // namespace tensorflow

View File

@@ -56,8 +56,8 @@ class SplitCollectionOperator {
// Create a new candidate and initialize it with the given example.
virtual void CreateAndInitializeCandidateWithExample(
const std::unique_ptr<TensorDataSet>& input_data, int example,
int32 node_id) const;
const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, int example, int32 node_id) const;
// Create a new GrowStats for the given node id and initialize it.
virtual void InitializeSlot(int32 node_id, int32 depth);

View File

@@ -130,6 +130,7 @@ message TensorForestParams {
bool collate_examples = 10;
bool checkpoint_stats = 11;
bool use_running_stats_method = 20;
bool initialize_average_splits = 22;
// Number of classes (classification) or targets (regression)
int32 num_outputs = 12;

View File

@@ -90,6 +90,7 @@ def build_params_proto(params):
proto.collate_examples = params.v4_collate_examples
proto.checkpoint_stats = params.v4_checkpoint_stats
proto.use_running_stats_method = params.v4_use_running_stats_method
proto.initialize_average_splits = params.v4_initialize_average_splits
if params.v4_prune_every_samples:
text_format.Merge(params.v4_prune_every_samples,
@@ -174,6 +175,8 @@ class V4ForestHParams(object):
self.v4_checkpoint_stats = getattr(self, 'v4_checkpoint_stats', False)
self.v4_use_running_stats_method = getattr(
self, 'v4_use_running_stats_method', False)
self.v4_initialize_average_splits = getattr(
self, 'v4_initialize_average_splits', False)
self.v4_param_file = getattr(self, 'v4_param_file', None)