From 1d825cff12c1def7626c62c26e911f5b529b8b66 Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Fri, 16 Aug 2019 13:33:25 -0700 Subject: [PATCH] Enable equality split for UpdateEnsembleV2. PiperOrigin-RevId: 263835411 --- tensorflow/core/kernels/boosted_trees/BUILD | 2 + .../kernels/boosted_trees/boosted_trees.proto | 20 +- .../core/kernels/boosted_trees/resources.cc | 64 ++- .../core/kernels/boosted_trees/resources.h | 25 +- .../core/kernels/boosted_trees/stats_ops.cc | 24 +- .../kernels/boosted_trees/training_ops.cc | 160 +++--- .../core/kernels/boosted_trees/tree_helper.h | 21 + .../boosted_trees/stats_ops_test.py | 7 +- .../boosted_trees/training_ops_test.py | 491 +++++++++++++++++- 9 files changed, 691 insertions(+), 123 deletions(-) diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD index 3c2bc929cc3..30f7697187e 100644 --- a/tensorflow/core/kernels/boosted_trees/BUILD +++ b/tensorflow/core/kernels/boosted_trees/BUILD @@ -46,6 +46,7 @@ cc_library( srcs = ["resources.cc"], hdrs = ["resources.h"], deps = [ + ":tree_helper", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc", @@ -95,6 +96,7 @@ tf_kernel_library( ":tree_helper", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc", "//third_party/eigen3", ], ) diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto index 4e0f4c7d56c..cd64effa5d8 100644 --- a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto +++ b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto @@ -48,6 +48,18 @@ message SparseVector { repeated float value = 2; } +enum SplitTypeWithDefault { + INEQUALITY_DEFAULT_LEFT = 0; + INEQUALITY_DEFAULT_RIGHT = 1; + EQUALITY_DEFAULT_RIGHT = 3; +} + +enum DefaultDirection { + // Left is the default direction. + DEFAULT_LEFT = 0; + DEFAULT_RIGHT = 1; +} + message BucketizedSplit { // Float feature column and split threshold describing // the rule feature <= threshold. @@ -56,11 +68,6 @@ message BucketizedSplit { // If feature column is multivalent, this holds the index of the dimension // for the split. Defaults to 0. int32 dimension_id = 5; - enum DefaultDirection { - // Left is the default direction. - DEFAULT_LEFT = 0; - DEFAULT_RIGHT = 1; - } // default direction for missing values. DefaultDirection default_direction = 6; @@ -75,6 +82,9 @@ message CategoricalSplit { // value. int32 feature_id = 1; int32 value = 2; + // If feature column is multivalent, this holds the index of the dimension + // for the split. Defaults to 0. + int32 dimension_id = 5; // Node children indexing into a contiguous // vector of nodes starting from the root. diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc index dadbfe47c52..85bd64e6802 100644 --- a/tensorflow/core/kernels/boosted_trees/resources.cc +++ b/tensorflow/core/kernels/boosted_trees/resources.cc @@ -14,8 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/kernels/boosted_trees/resources.h" + #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h" +#include "tensorflow/core/kernels/boosted_trees/tree_helper.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" @@ -265,11 +267,50 @@ int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight, } void BoostedTreesEnsembleResource::AddBucketizedSplitNode( - const int32 tree_id, const int32 node_id, const int32 feature_id, - const int32 dimension_id, const int32 threshold, const float gain, - const float left_contrib, const float right_contrib, int32* left_node_id, - int32* right_node_id) { + const int32 tree_id, + const std::pair& split_entry, + int32* left_node_id, int32* right_node_id) { + const auto candidate = split_entry.second; + auto* node = AddLeafNodes(tree_id, split_entry, left_node_id, right_node_id); + auto* new_split = node->mutable_bucketized_split(); + new_split->set_feature_id(candidate.feature_idx); + new_split->set_threshold(candidate.threshold); + new_split->set_dimension_id(candidate.dimension_id); + new_split->set_left_id(*left_node_id); + new_split->set_right_id(*right_node_id); + + boosted_trees::SplitTypeWithDefault split_type_with_default; + bool parsed = boosted_trees::SplitTypeWithDefault_Parse( + candidate.split_type, &split_type_with_default); + DCHECK(parsed); + if (split_type_with_default == boosted_trees::INEQUALITY_DEFAULT_RIGHT) { + new_split->set_default_direction(boosted_trees::DEFAULT_RIGHT); + } else { + new_split->set_default_direction(boosted_trees::DEFAULT_LEFT); + } +} + +void BoostedTreesEnsembleResource::AddCategoricalSplitNode( + const int32 tree_id, + const std::pair& split_entry, + int32* left_node_id, int32* right_node_id) { + const auto candidate = split_entry.second; + auto* node = AddLeafNodes(tree_id, split_entry, left_node_id, right_node_id); + auto* new_split = node->mutable_categorical_split(); + new_split->set_feature_id(candidate.feature_idx); + new_split->set_value(candidate.threshold); + new_split->set_dimension_id(candidate.dimension_id); + new_split->set_left_id(*left_node_id); + new_split->set_right_id(*right_node_id); +} + +boosted_trees::Node* BoostedTreesEnsembleResource::AddLeafNodes( + const int32 tree_id, + const std::pair& split_entry, + int32* left_node_id, int32* right_node_id) { auto* tree = tree_ensemble_->mutable_trees(tree_id); + const auto node_id = split_entry.first; + const auto candidate = split_entry.second; auto* node = tree->mutable_nodes(node_id); DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf); float prev_node_value = node->leaf().scalar(); @@ -282,16 +323,13 @@ void BoostedTreesEnsembleResource::AddBucketizedSplitNode( node->mutable_metadata()->mutable_original_leaf()->Swap( node->mutable_leaf()); } - node->mutable_metadata()->set_gain(gain); - auto* new_split = node->mutable_bucketized_split(); - new_split->set_feature_id(feature_id); - new_split->set_threshold(threshold); - new_split->set_dimension_id(dimension_id); - new_split->set_left_id(*left_node_id); - new_split->set_right_id(*right_node_id); + node->mutable_metadata()->set_gain(candidate.gain); // TODO(npononareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE. 
- left_node->mutable_leaf()->set_scalar(prev_node_value + left_contrib); - right_node->mutable_leaf()->set_scalar(prev_node_value + right_contrib); + left_node->mutable_leaf()->set_scalar(prev_node_value + + candidate.left_node_contrib); + right_node->mutable_leaf()->set_scalar(prev_node_value + + candidate.right_node_contrib); + return node; } void BoostedTreesEnsembleResource::Reset() { diff --git a/tensorflow/core/kernels/boosted_trees/resources.h b/tensorflow/core/kernels/boosted_trees/resources.h index ce7014d111d..e1ca8cc5a6b 100644 --- a/tensorflow/core/kernels/boosted_trees/resources.h +++ b/tensorflow/core/kernels/boosted_trees/resources.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/boosted_trees/tree_helper.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" @@ -25,6 +26,7 @@ namespace tensorflow { // Forward declaration for proto class TreeEnsemble namespace boosted_trees { class TreeEnsemble; +class Node; } // namespace boosted_trees // A StampedResource is a resource that has a stamp token associated with it. @@ -105,13 +107,17 @@ class BoostedTreesEnsembleResource : public StampedResource { // Adds new tree with one node to the ensemble and sets node's value to logits int32 AddNewTreeWithLogits(const float weight, const float logits); - // Grows the tree by adding a split and leaves. - void AddBucketizedSplitNode(const int32 tree_id, const int32 node_id, - const int32 feature_id, const int32 dimension_id, - const int32 threshold, const float gain, - const float left_contrib, - const float right_contrib, int32* left_node_id, - int32* right_node_id); + // Grows the tree by adding a bucketized split and leaves. + void AddBucketizedSplitNode( + const int32 tree_id, + const std::pair& split_entry, + int32* left_node_id, int32* right_node_id); + + // Grows the tree by adding a categorical split and leaves. + void AddCategoricalSplitNode( + const int32 tree_id, + const std::pair& split_entry, + int32* left_node_id, int32* right_node_id); // Retrieves tree weights and returns as a vector. // It involves a copy, so should be called only sparingly (like once per @@ -167,6 +173,11 @@ class BoostedTreesEnsembleResource : public StampedResource { protobuf::Arena arena_; mutex mu_; boosted_trees::TreeEnsemble* tree_ensemble_; + + boosted_trees::Node* AddLeafNodes( + int32 tree_id, + const std::pair& split_entry, + int32* left_node_id, int32* right_node_id); }; } // namespace tensorflow diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc index 03b9809d97a..45dc248bffd 100644 --- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc @@ -20,16 +20,12 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h" #include "tensorflow/core/kernels/boosted_trees/tree_helper.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { -// TODO(tanzheny): Make these const as proto enum. 
-const char kInequalityDefaultLeft[] = "inequality_default_left"; -const char kInequalityDefaultRight[] = "inequality_default_right"; -const char kEqualityDefaultRight[] = "equality_default_right"; - using Matrix = Eigen::Matrix; using ConstMatrixMap = Eigen::Map; @@ -459,6 +455,12 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel { cum_hess.push_back(total_hess); } } + const string kInequalityDefaultLeft = + boosted_trees::SplitTypeWithDefault_Name( + boosted_trees::INEQUALITY_DEFAULT_LEFT); + const string kInequalityDefaultRight = + boosted_trees::SplitTypeWithDefault_Name( + boosted_trees::INEQUALITY_DEFAULT_RIGHT); // Iterate from left to right, excluding default bucket. for (int bucket = 0; bucket < num_buckets; ++bucket) { @@ -491,6 +493,9 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel { const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim, string* best_split_type, Eigen::VectorXf* best_contrib_for_left, Eigen::VectorXf* best_contrib_for_right) { + const string kEqualityDefaultRight = + boosted_trees::SplitTypeWithDefault_Name( + boosted_trees::EQUALITY_DEFAULT_RIGHT); for (int f_dim = 0; f_dim < feature_dims; ++f_dim) { for (int bucket = 0; bucket < num_buckets; ++bucket) { ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0), @@ -734,7 +739,8 @@ class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel { float best_gain = std::numeric_limits::lowest(); float best_bucket = 0; float best_f_dim = 0; - string best_split_type = kInequalityDefaultLeft; + string best_split_type = boosted_trees::SplitTypeWithDefault_Name( + boosted_trees::INEQUALITY_DEFAULT_LEFT); float best_contrib_for_left = 0.0; float best_contrib_for_right = 0.0; // the sum of gradients including default bucket. @@ -801,7 +807,8 @@ class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel { best_gain = gain_for_left + gain_for_right; best_bucket = bucket_id; best_f_dim = feature_dim; - best_split_type = kInequalityDefaultRight; + best_split_type = boosted_trees::SplitTypeWithDefault_Name( + boosted_trees::INEQUALITY_DEFAULT_RIGHT); best_contrib_for_left = contrib_for_left[0]; best_contrib_for_right = contrib_for_right[0]; } @@ -818,7 +825,8 @@ class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel { best_gain = gain_for_left + gain_for_right; best_bucket = bucket_id; best_f_dim = feature_dim; - best_split_type = kInequalityDefaultLeft; + best_split_type = boosted_trees::SplitTypeWithDefault_Name( + boosted_trees::INEQUALITY_DEFAULT_LEFT); best_contrib_for_left = contrib_for_left[0]; best_contrib_for_right = contrib_for_right[0]; } diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc index 7bcfb339e7c..dd8abcea65c 100644 --- a/tensorflow/core/kernels/boosted_trees/training_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/training_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h" #include "tensorflow/core/kernels/boosted_trees/resources.h" #include "tensorflow/core/kernels/boosted_trees/tree_helper.h" #include "tensorflow/core/lib/core/refcount.h" @@ -26,19 +27,6 @@ namespace { constexpr float kLayerByLayerTreeWeight = 1.0; constexpr float kMinDeltaForCenterBias = 0.01; -// TODO(nponomareva, youngheek): consider using vector. 
-struct SplitCandidate { - SplitCandidate() {} - - // Index in the list of the feature ids. - int64 feature_idx; - - // Index in the tensor of node_ids for the feature with idx feature_idx. - int64 candidate_idx; - - float gain; -}; - enum PruningMode { kNoPruning = 0, kPrePruning = 1, kPostPruning = 2 }; } // namespace @@ -91,9 +79,10 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { const auto learning_rate = learning_rate_t->scalar()(); // Find best splits for each active node. - std::map best_splits; - FindBestSplitsPerNode(context, node_ids_list, gains_list, feature_ids, - &best_splits); + std::map best_splits; + FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list, + thresholds_list, left_node_contribs, + right_node_contribs, feature_ids, &best_splits); int32 current_tree = UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource); @@ -113,17 +102,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { int32 node_id_start = ensemble_resource->GetNumNodes(current_tree); // Add the splits to the tree. for (auto& split_entry : best_splits) { - const int32 node_id = split_entry.first; - const SplitCandidate& candidate = split_entry.second; - - const int64 feature_idx = candidate.feature_idx; - const int64 candidate_idx = candidate.candidate_idx; - - const int32 feature_id = feature_ids(feature_idx); - const int32 threshold = - thresholds_list[feature_idx].vec()(candidate_idx); - const float gain = gains_list[feature_idx].vec()(candidate_idx); - + const float gain = split_entry.second.gain; if (pruning_mode_ == kPrePruning) { // Don't consider negative splits if we're pre-pruning the tree. // Note that zero-gain splits are acceptable. @@ -131,22 +110,13 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { continue; } } - // For now assume that the weights vectors are one dimensional. - // TODO(nponomareva): change here for multiclass. - const float left_contrib = - learning_rate * - left_node_contribs[feature_idx].matrix()(candidate_idx, 0); - const float right_contrib = - learning_rate * - right_node_contribs[feature_idx].matrix()(candidate_idx, 0); // unused. int32 left_node_id; int32 right_node_id; - ensemble_resource->AddBucketizedSplitNode( - current_tree, node_id, feature_id, 0, threshold, gain, left_contrib, - right_contrib, &left_node_id, &right_node_id); + ensemble_resource->AddBucketizedSplitNode(current_tree, split_entry, + &left_node_id, &right_node_id); split_happened = true; } int32 node_id_end = ensemble_resource->GetNumNodes(current_tree); @@ -196,14 +166,22 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { // Helper method which effectively does a reduce over all split candidates // and finds the best split for each node. void FindBestSplitsPerNode( - OpKernelContext* const context, const OpInputList& node_ids_list, - const OpInputList& gains_list, + OpKernelContext* const context, const float learning_rate, + const OpInputList& node_ids_list, const OpInputList& gains_list, + const OpInputList& thresholds_list, + const OpInputList& left_node_contribs_list, + const OpInputList& right_node_contribs_list, const TTypes::Vec& feature_ids, - std::map* best_split_per_node) { + std::map* best_split_per_node) { // Find best split per node going through every feature candidate. 
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) { const auto& node_ids = node_ids_list[feature_idx].vec(); const auto& gains = gains_list[feature_idx].vec(); + const auto& thresholds = thresholds_list[feature_idx].vec(); + const auto& left_node_contribs = + left_node_contribs_list[feature_idx].matrix(); + const auto& right_node_contribs = + right_node_contribs_list[feature_idx].matrix(); for (size_t candidate_idx = 0; candidate_idx < node_ids.size(); ++candidate_idx) { @@ -212,16 +190,24 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel { const auto& gain = gains(candidate_idx); auto best_split_it = best_split_per_node->find(node_id); - SplitCandidate candidate; - candidate.feature_idx = feature_idx; + boosted_trees::SplitCandidate candidate; + candidate.feature_idx = feature_ids(feature_idx); candidate.candidate_idx = candidate_idx; candidate.gain = gain; + candidate.dimension_id = 0; + candidate.threshold = thresholds(candidate_idx); + candidate.left_node_contrib = + learning_rate * left_node_contribs(candidate_idx, 0); + candidate.right_node_contrib = + learning_rate * right_node_contribs(candidate_idx, 0); + candidate.split_type = boosted_trees::SplitTypeWithDefault_Name( + boosted_trees::INEQUALITY_DEFAULT_LEFT); if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() && GainsAreEqual(gain, best_split_it->second.gain))) { const auto best_candidate = (*best_split_per_node)[node_id]; - const int32 best_feature_id = feature_ids(best_candidate.feature_idx); - const int32 feature_id = feature_ids(candidate.feature_idx); + const int32 best_feature_id = best_candidate.feature_idx; + const int32 feature_id = candidate.feature_idx; VLOG(2) << "Breaking ties on feature ids and buckets"; // Breaking ties deterministically. if (feature_id < best_feature_id) { @@ -299,9 +285,11 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel { static_cast(pruning_mode_t->scalar()()); // Find best splits for each active node. - std::map best_splits; - FindBestSplitsPerNode(context, node_ids_list, gains_list, feature_ids, - &best_splits); + std::map best_splits; + FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list, + thresholds_list, dimension_ids_list, + left_node_contribs, right_node_contribs, + split_types_list, feature_ids, &best_splits); int32 current_tree = UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource); @@ -321,19 +309,8 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel { int32 node_id_start = ensemble_resource->GetNumNodes(current_tree); // Add the splits to the tree. for (auto& split_entry : best_splits) { - const int32 node_id = split_entry.first; - const SplitCandidate& candidate = split_entry.second; - - const int64 feature_idx = candidate.feature_idx; - const int32 feature_id = feature_ids(feature_idx); - - const int64 candidate_idx = candidate.candidate_idx; - - const int32 dimension_id = - dimension_ids_list[feature_idx].vec()(candidate_idx); - const int32 threshold = - thresholds_list[feature_idx].vec()(candidate_idx); - const float gain = gains_list[feature_idx].vec()(candidate_idx); + const float gain = split_entry.second.gain; + const string split_type = split_entry.second.split_type; if (pruning_mode == kPrePruning) { // Don't consider negative splits if we're pre-pruning the tree. @@ -343,22 +320,23 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel { } } - // TODO(crawles): change here for multiclass. 
- const float left_contrib = - learning_rate * - left_node_contribs[feature_idx].matrix()(candidate_idx, 0); - const float right_contrib = - learning_rate * - right_node_contribs[feature_idx].matrix()(candidate_idx, 0); - // unused. int32 left_node_id; int32 right_node_id; - // TODO(tanzheny): add categorical split. - ensemble_resource->AddBucketizedSplitNode( - current_tree, node_id, feature_id, dimension_id, threshold, gain, - left_contrib, right_contrib, &left_node_id, &right_node_id); + boosted_trees::SplitTypeWithDefault split_type_with_default; + bool parsed = boosted_trees::SplitTypeWithDefault_Parse( + split_type, &split_type_with_default); + DCHECK(parsed); + if (split_type_with_default == boosted_trees::EQUALITY_DEFAULT_RIGHT) { + // Add equality split to the node. + ensemble_resource->AddCategoricalSplitNode( + current_tree, split_entry, &left_node_id, &right_node_id); + } else { + // Add inequality split to the node. + ensemble_resource->AddBucketizedSplitNode( + current_tree, split_entry, &left_node_id, &right_node_id); + } split_happened = true; } int32 node_id_end = ensemble_resource->GetNumNodes(current_tree); @@ -408,32 +386,54 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel { // Helper method which effectively does a reduce over all split candidates // and finds the best split for each node. void FindBestSplitsPerNode( - OpKernelContext* const context, const OpInputList& node_ids_list, - const OpInputList& gains_list, + OpKernelContext* const context, const float learning_rate, + const OpInputList& node_ids_list, const OpInputList& gains_list, + const OpInputList& thresholds_list, const OpInputList& dimension_ids_list, + const OpInputList& left_node_contribs_list, + const OpInputList& right_node_contribs_list, + const OpInputList& split_types_list, const TTypes::Vec& feature_ids, - std::map* best_split_per_node) { + std::map* best_split_per_node) { // Find best split per node going through every feature candidate. for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) { const auto& node_ids = node_ids_list[feature_idx].vec(); const auto& gains = gains_list[feature_idx].vec(); + const auto& thresholds = thresholds_list[feature_idx].vec(); + const auto& dimension_ids = dimension_ids_list[feature_idx].vec(); + const auto& left_node_contribs = + left_node_contribs_list[feature_idx].matrix(); + const auto& right_node_contribs = + right_node_contribs_list[feature_idx].matrix(); + const auto& split_types = split_types_list[feature_idx].vec(); for (size_t candidate_idx = 0; candidate_idx < node_ids.size(); ++candidate_idx) { // Get current split candidate. const auto& node_id = node_ids(candidate_idx); const auto& gain = gains(candidate_idx); + const auto& threshold = thresholds(candidate_idx); + const auto& dimension_id = dimension_ids(candidate_idx); + const auto& split_type = split_types(candidate_idx); auto best_split_it = best_split_per_node->find(node_id); - SplitCandidate candidate; - candidate.feature_idx = feature_idx; + boosted_trees::SplitCandidate candidate; + candidate.feature_idx = feature_ids(feature_idx); candidate.candidate_idx = candidate_idx; candidate.gain = gain; + candidate.threshold = threshold; + candidate.dimension_id = dimension_id; + // TODO(crawles): change here for multiclass. 
+ candidate.left_node_contrib = + learning_rate * left_node_contribs(candidate_idx, 0); + candidate.right_node_contrib = + learning_rate * right_node_contribs(candidate_idx, 0); + candidate.split_type = split_type; if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() && GainsAreEqual(gain, best_split_it->second.gain))) { const auto best_candidate = (*best_split_per_node)[node_id]; - const int32 best_feature_id = feature_ids(best_candidate.feature_idx); - const int32 feature_id = feature_ids(candidate.feature_idx); + const int32 best_feature_id = best_candidate.feature_idx; + const int32 feature_id = candidate.feature_idx; VLOG(2) << "Breaking ties on feature ids and buckets"; // Breaking ties deterministically. if (feature_id < best_feature_id) { diff --git a/tensorflow/core/kernels/boosted_trees/tree_helper.h b/tensorflow/core/kernels/boosted_trees/tree_helper.h index 4a4aafd0e52..c007dc195ba 100644 --- a/tensorflow/core/kernels/boosted_trees/tree_helper.h +++ b/tensorflow/core/kernels/boosted_trees/tree_helper.h @@ -25,6 +25,27 @@ limitations under the License. namespace tensorflow { +namespace boosted_trees { +// TODO(nponomareva, youngheek): consider using vector. +struct SplitCandidate { + SplitCandidate() {} + + // Index in the list of the feature ids. + int64 feature_idx = 0; + + // Index in the tensor of node_ids for the feature with idx feature_idx. + int64 candidate_idx = 0; + + float gain = 0.0; + int32 threshold = 0.0; + int32 dimension_id = 0; + float left_node_contrib = 0.0; + float right_node_contrib = 0.0; + // The split type, i.e., with missing value to left/right. + string split_type; +}; +} // namespace boosted_trees + static bool GainsAreEqual(const float g1, const float g2) { const float kTolerance = 1e-15; return std::abs(g1 - g2) < kTolerance; diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py index c5d238ba149..402c6f041e0 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py @@ -28,10 +28,9 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import googletest -_INEQUALITY_DEFAULT_LEFT = 'inequality_default_left'.encode('utf-8') -_INEQUALITY_DEFAULT_RIGHT = 'inequality_default_right'.encode('utf-8') -_EQUALITY_DEFAULT_LEFT = 'equality_default_left'.encode('utf-8') -_EQUALITY_DEFAULT_RIGHT = 'equality_default_right'.encode('utf-8') +_INEQUALITY_DEFAULT_LEFT = 'INEQUALITY_DEFAULT_LEFT'.encode('utf-8') +_INEQUALITY_DEFAULT_RIGHT = 'INEQUALITY_DEFAULT_RIGHT'.encode('utf-8') +_EQUALITY_DEFAULT_RIGHT = 'EQUALITY_DEFAULT_RIGHT'.encode('utf-8') class StatsOpsTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py index d6636c92706..b12553ff2ac 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py @@ -26,6 +26,10 @@ from tensorflow.python.ops import boosted_trees_ops from tensorflow.python.ops import resources from tensorflow.python.platform import googletest +_INEQUALITY_DEFAULT_LEFT = 'INEQUALITY_DEFAULT_LEFT'.encode('utf-8') +_INEQUALITY_DEFAULT_RIGHT = 'INEQUALITY_DEFAULT_RIGHT'.encode('utf-8') +_EQUALITY_DEFAULT_RIGHT = 'EQUALITY_DEFAULT_RIGHT'.encode('utf-8') + class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): """Tests 
for growing tree ensemble from split candidates.""" @@ -158,7 +162,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): feature1_thresholds = np.array([52], dtype=np.int32) feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32) feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32) - feature1_inequality_split_types = np.array(['inequality_default_left']) + feature1_inequality_split_types = np.array([_INEQUALITY_DEFAULT_LEFT]) # Feature split with the highest gain. feature2_nodes = np.array([0], dtype=np.int32) @@ -167,7 +171,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): feature2_thresholds = np.array([7], dtype=np.int32) feature2_left_node_contribs = np.array([[-4.89]], dtype=np.float32) feature2_right_node_contribs = np.array([[5.3]], dtype=np.float32) - feature2_inequality_split_types = np.array(['inequality_default_right']) + feature2_inequality_split_types = np.array([_INEQUALITY_DEFAULT_RIGHT]) # Grow tree ensemble. grow_op = boosted_trees_ops.update_ensemble_v2( @@ -208,6 +212,116 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): dimension_id: 1 left_id: 1 right_id: 2 + default_direction: DEFAULT_RIGHT + } + metadata { + gain: 7.65 + } + } + nodes { + leaf { + scalar: -0.489 + } + } + nodes { + leaf { + scalar: 0.53 + } + } + } + trees { + nodes { + leaf { + scalar: 0.0 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + last_layer_node_start: 0 + last_layer_node_end: 1 + } + """ + self.assertEqual(new_stamp, 1) + self.assertProtoEquals(expected_result, tree_ensemble) + + @test_util.run_deprecated_v1 + def testGrowWithEmptyEnsembleV2EqualitySplit(self): + """Test growing an empty ensemble.""" + with self.cached_session() as session: + # Create empty ensemble. + tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble') + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + feature_ids = [0, 6] + + # Prepare feature inputs. + feature1_nodes = np.array([0], dtype=np.int32) + feature1_gains = np.array([7.62], dtype=np.float32) + feature1_dimensions = np.array([0], dtype=np.int32) + feature1_thresholds = np.array([52], dtype=np.int32) + feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32) + feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32) + feature1_inequality_split_types = np.array([_INEQUALITY_DEFAULT_LEFT]) + + # Feature split with the highest gain. + feature2_nodes = np.array([0], dtype=np.int32) + feature2_gains = np.array([7.65], dtype=np.float32) + feature2_dimensions = np.array([1], dtype=np.int32) + feature2_thresholds = np.array([7], dtype=np.int32) + feature2_left_node_contribs = np.array([[-4.89]], dtype=np.float32) + feature2_right_node_contribs = np.array([[5.3]], dtype=np.float32) + feature2_inequality_split_types = np.array([_EQUALITY_DEFAULT_RIGHT]) + + # Grow tree ensemble. + grow_op = boosted_trees_ops.update_ensemble_v2( + tree_ensemble_handle, + learning_rate=0.1, + pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING, + # Tree will be finalized now, since we will reach depth 1. 
+ max_depth=1, + feature_ids=feature_ids, + dimension_ids=[feature1_dimensions, feature2_dimensions], + node_ids=[feature1_nodes, feature2_nodes], + gains=[feature1_gains, feature2_gains], + thresholds=[feature1_thresholds, feature2_thresholds], + left_node_contribs=[ + feature1_left_node_contribs, feature2_left_node_contribs + ], + right_node_contribs=[ + feature1_right_node_contribs, feature2_right_node_contribs + ], + split_types=[ + feature1_inequality_split_types, feature2_inequality_split_types + ], + ) + session.run(grow_op) + + new_stamp, serialized = session.run(tree_ensemble.serialize()) + + tree_ensemble = boosted_trees_pb2.TreeEnsemble() + tree_ensemble.ParseFromString(serialized) + + # Note that since the tree is finalized, we added a new dummy tree. + expected_result = """ + trees { + nodes { + categorical_split { + feature_id: 6 + value: 7 + dimension_id: 1 + left_id: 1 + right_id: 2 } metadata { gain: 7.65 @@ -537,7 +651,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): feature1_thresholds = np.array([21], dtype=np.int32) feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32) feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32) - feature1_split_types = np.array(['inequality_default_left']) + feature1_split_types = np.array([_INEQUALITY_DEFAULT_LEFT]) feature2_nodes = np.array([1, 2], dtype=np.int32) feature2_gains = np.array([0.63, 2.7], dtype=np.float32) @@ -546,7 +660,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32) feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32) feature2_split_types = np.array( - ['inequality_default_right', 'inequality_default_right']) + [_INEQUALITY_DEFAULT_RIGHT, _INEQUALITY_DEFAULT_RIGHT]) feature3_nodes = np.array([2], dtype=np.int32) feature3_gains = np.array([1.7], dtype=np.float32) @@ -554,7 +668,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): feature3_thresholds = np.array([3], dtype=np.int32) feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32) feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32) - feature3_split_types = np.array(['inequality_default_left']) + feature3_split_types = np.array([_INEQUALITY_DEFAULT_LEFT]) # Grow tree ensemble. 
grow_op = boosted_trees_ops.update_ensemble_v2( @@ -629,6 +743,212 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): threshold: 7 left_id: 5 right_id: 6 + default_direction: DEFAULT_RIGHT + } + metadata { + gain: 2.7 + original_leaf { + scalar: -0.4375 + } + } + } + nodes { + leaf { + scalar: 0.114 + } + } + nodes { + leaf { + scalar: 0.879 + } + } + nodes { + leaf { + scalar: -0.5875 + } + } + nodes { + leaf { + scalar: -0.2075 + } + } + } + trees { + nodes { + leaf { + scalar: 0.0 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + is_finalized: true + num_layers_grown: 2 + } + tree_metadata { + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + last_layer_node_start: 0 + last_layer_node_end: 1 + } + """ + self.assertEqual(new_stamp, 1) + self.assertProtoEquals(expected_result, tree_ensemble) + + @test_util.run_deprecated_v1 + def testGrowExistingEnsembleTreeV2NotFinalizedEqualitySplit(self): + """Test growing an existing ensemble with the last tree not finalized.""" + with self.cached_session() as session: + tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() + text_format.Merge( + """ + trees { + nodes { + bucketized_split { + feature_id: 4 + dimension_id: 0 + left_id: 1 + right_id: 2 + } + metadata { + gain: 7.62 + } + } + nodes { + leaf { + scalar: 0.714 + } + } + nodes { + leaf { + scalar: -0.4375 + } + } + } + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 1 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + } + """, tree_ensemble_config) + + # Create existing ensemble with one root split + tree_ensemble = boosted_trees_ops.TreeEnsemble( + 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare feature inputs. + # feature 1 only has a candidate for node 1, feature 2 has candidates + # for both nodes and feature 3 only has a candidate for node 2. + + feature_ids = [0, 1, 0] + + feature1_nodes = np.array([1], dtype=np.int32) + feature1_gains = np.array([1.4], dtype=np.float32) + feature1_dimensions = np.array([0], dtype=np.int32) + feature1_thresholds = np.array([21], dtype=np.int32) + feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32) + feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32) + feature1_split_types = np.array([_INEQUALITY_DEFAULT_LEFT]) + + feature2_nodes = np.array([1, 2], dtype=np.int32) + feature2_gains = np.array([0.63, 2.7], dtype=np.float32) + feature2_dimensions = np.array([1, 3], dtype=np.int32) + feature2_thresholds = np.array([23, 7], dtype=np.int32) + feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32) + feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32) + feature2_split_types = np.array( + [_EQUALITY_DEFAULT_RIGHT, _EQUALITY_DEFAULT_RIGHT]) + + feature3_nodes = np.array([2], dtype=np.int32) + feature3_gains = np.array([1.7], dtype=np.float32) + feature3_dimensions = np.array([0], dtype=np.int32) + feature3_thresholds = np.array([3], dtype=np.int32) + feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32) + feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32) + feature3_split_types = np.array([_INEQUALITY_DEFAULT_LEFT]) + + # Grow tree ensemble. 
+ grow_op = boosted_trees_ops.update_ensemble_v2( + tree_ensemble_handle, + learning_rate=0.1, + pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING, + # tree is going to be finalized now, since we reach depth 2. + max_depth=2, + feature_ids=feature_ids, + dimension_ids=[ + feature1_dimensions, feature2_dimensions, feature3_dimensions + ], + node_ids=[feature1_nodes, feature2_nodes, feature3_nodes], + gains=[feature1_gains, feature2_gains, feature3_gains], + thresholds=[ + feature1_thresholds, feature2_thresholds, feature3_thresholds + ], + left_node_contribs=[ + feature1_left_node_contribs, feature2_left_node_contribs, + feature3_left_node_contribs + ], + right_node_contribs=[ + feature1_right_node_contribs, feature2_right_node_contribs, + feature3_right_node_contribs + ], + split_types=[ + feature1_split_types, feature2_split_types, feature3_split_types + ], + ) + session.run(grow_op) + + # Expect the split for node 1 to be chosen from feature 1 and + # the split for node 2 to be chosen from feature 2. + # The grown tree should be finalized as max tree depth is 2 and we have + # grown 2 layers. + new_stamp, serialized = session.run(tree_ensemble.serialize()) + tree_ensemble = boosted_trees_pb2.TreeEnsemble() + tree_ensemble.ParseFromString(serialized) + + expected_result = """ + trees { + nodes { + bucketized_split { + feature_id: 4 + dimension_id: 0 + left_id: 1 + right_id: 2 + } + metadata { + gain: 7.62 + } + } + nodes { + bucketized_split { + feature_id: 0 + threshold: 21 + dimension_id: 0 + left_id: 3 + right_id: 4 + } + metadata { + gain: 1.4 + original_leaf { + scalar: 0.714 + } + } + } + nodes { + categorical_split { + feature_id: 1 + dimension_id: 3 + value: 7 + left_id: 5 + right_id: 6 } metadata { gain: 2.7 @@ -900,7 +1220,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): feature1_thresholds = np.array([21], dtype=np.int32) feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32) feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32) - feature1_split_types = np.array(['inequality_default_right']) + feature1_split_types = np.array([_INEQUALITY_DEFAULT_RIGHT]) # Grow tree ensemble. 
grow_op = boosted_trees_ops.update_ensemble_v2( @@ -955,6 +1275,165 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase): threshold: 21 left_id: 1 right_id: 2 + default_direction: DEFAULT_RIGHT + } + metadata { + gain: -1.4 + } + } + nodes { + leaf { + scalar: -0.6 + } + } + nodes { + leaf { + scalar: 0.165 + } + } + } + tree_weights: 0.15 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + num_layers_grown: 1 + is_finalized: false + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 2 + last_layer_node_start: 1 + last_layer_node_end: 3 + } + """ + self.assertEqual(new_stamp, 1) + self.assertProtoEquals(expected_result, tree_ensemble) + + @test_util.run_deprecated_v1 + def testGrowExistingEnsembleTreeV2FinalizedEqualitySplit(self): + """Test growing an existing ensemble with the last tree finalized.""" + with self.cached_session() as session: + tree_ensemble_config = boosted_trees_pb2.TreeEnsemble() + text_format.Merge( + """ + trees { + nodes { + bucketized_split { + feature_id: 4 + dimension_id: 0 + left_id: 1 + right_id: 2 + } + metadata { + gain: 7.62 + } + } + nodes { + leaf { + scalar: 7.14 + } + } + nodes { + leaf { + scalar: -4.375 + } + } + } + trees { + nodes { + leaf { + scalar: 0.0 + } + } + } + tree_weights: 0.15 + tree_weights: 1.0 + tree_metadata { + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + } + """, tree_ensemble_config) + + # Create existing ensemble with one root split + tree_ensemble = boosted_trees_ops.TreeEnsemble( + 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString()) + tree_ensemble_handle = tree_ensemble.resource_handle + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare feature inputs. + + feature_ids = [75] + + feature1_nodes = np.array([0], dtype=np.int32) + feature1_gains = np.array([-1.4], dtype=np.float32) + feature1_dimensions = np.array([1], dtype=np.int32) + feature1_thresholds = np.array([21], dtype=np.int32) + feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32) + feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32) + feature1_split_types = np.array([_EQUALITY_DEFAULT_RIGHT]) + + # Grow tree ensemble. + grow_op = boosted_trees_ops.update_ensemble_v2( + tree_ensemble_handle, + pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING, + learning_rate=0.1, + max_depth=2, + feature_ids=feature_ids, + dimension_ids=[feature1_dimensions], + node_ids=[feature1_nodes], + gains=[feature1_gains], + thresholds=[feature1_thresholds], + left_node_contribs=[feature1_left_node_contribs], + right_node_contribs=[feature1_right_node_contribs], + split_types=[feature1_split_types]) + session.run(grow_op) + + # Expect a new tree added, with a split on feature 75 + new_stamp, serialized = session.run(tree_ensemble.serialize()) + tree_ensemble = boosted_trees_pb2.TreeEnsemble() + tree_ensemble.ParseFromString(serialized) + + expected_result = """ + trees { + nodes { + bucketized_split { + feature_id: 4 + dimension_id: 0 + left_id: 1 + right_id: 2 + } + metadata { + gain: 7.62 + } + } + nodes { + leaf { + scalar: 7.14 + } + } + nodes { + leaf { + scalar: -4.375 + } + } + } + trees { + nodes { + categorical_split { + feature_id: 75 + dimension_id: 1 + value: 21 + left_id: 1 + right_id: 2 } metadata { gain: -1.4
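
For reference, a minimal usage sketch (not part of the patch) showing how the new EQUALITY_DEFAULT_RIGHT split type is consumed by BoostedTreesUpdateEnsembleV2 from Python. It mirrors the tests above: the op signature, the split-type strings, and the internal modules are taken from this change, while the resource name, the standalone Graph/Session setup, and the specific candidate values are illustrative assumptions.

import numpy as np

from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import boosted_trees_ops
from tensorflow.python.ops import resources

with ops.Graph().as_default():
  # Empty ensemble with a single growable tree (resource name is arbitrary).
  ensemble = boosted_trees_ops.TreeEnsemble('ensemble_sketch')

  # One candidate for the root node. Because the split type is
  # EQUALITY_DEFAULT_RIGHT, the kernel records a categorical_split node
  # (examples whose feature value equals `value` go left; all others,
  # including missing, default right) instead of a bucketized_split.
  grow_op = boosted_trees_ops.update_ensemble_v2(
      ensemble.resource_handle,
      learning_rate=0.1,
      pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
      max_depth=1,
      feature_ids=[6],
      dimension_ids=[np.array([0], dtype=np.int32)],
      node_ids=[np.array([0], dtype=np.int32)],
      gains=[np.array([7.65], dtype=np.float32)],
      # For equality splits the threshold input carries the categorical value.
      thresholds=[np.array([7], dtype=np.int32)],
      left_node_contribs=[np.array([[-4.89]], dtype=np.float32)],
      right_node_contribs=[np.array([[5.3]], dtype=np.float32)],
      split_types=[np.array([b'EQUALITY_DEFAULT_RIGHT'])])

  with session_lib.Session() as sess:
    resources.initialize_resources(resources.shared_resources()).run()
    sess.run(grow_op)
    stamp_token, serialized = sess.run(ensemble.serialize())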