diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 8eac1243ef6..f03eab510c2 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -445,6 +445,7 @@ tf_kernel_library( "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", + "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", "//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource", "//tensorflow/core:framework_headers_lib", diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index d9e7a0f4660..64349cfca39 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -383,19 +383,20 @@ class BuildDenseInequalitySplitsOp : public OpKernel { best_gain -= num_elements * state->tree_complexity_regularization(); ObliviousSplitInfo oblivious_split_info; - auto* oblivious_dense_split = oblivious_split_info.mutable_split_node() - ->mutable_dense_float_binary_split(); + auto* oblivious_dense_split = + oblivious_split_info.mutable_split_node() + ->mutable_oblivious_dense_float_binary_split(); oblivious_dense_split->set_feature_column(state->feature_column_group_id()); oblivious_dense_split->set_threshold( bucket_boundaries(bucket_ids(best_bucket_idx, 0))); (*gains)(0) = best_gain; for (int root_idx = 0; root_idx < num_elements; root_idx++) { - auto* left_children = oblivious_split_info.add_children_leaves(); - auto* right_children = oblivious_split_info.add_children_leaves(); + auto* left_child = oblivious_split_info.add_children(); + auto* right_child = oblivious_split_info.add_children(); - 
state->FillLeaf(best_left_node_stats[root_idx], left_children); - state->FillLeaf(best_right_node_stats[root_idx], right_children); + state->FillLeaf(best_left_node_stats[root_idx], left_child); + state->FillLeaf(best_right_node_stats[root_idx], right_child); const int start_index = partition_boundaries[root_idx]; (*output_partition_ids)(root_idx) = partition_ids(start_index); diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 6d9a6ee5a0d..bb5ae78d9bf 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -15,6 +15,7 @@ #include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h" #include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" #include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h" +#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -26,6 +27,7 @@ namespace boosted_trees { namespace { +using boosted_trees::learner::LearnerConfig; using boosted_trees::learner::LearningRateConfig; using boosted_trees::trees::Leaf; using boosted_trees::trees::TreeNode; @@ -42,6 +44,9 @@ struct SplitCandidate { // Split info. learner::SplitInfo split_info; + + // Oblivious split info. + learner::ObliviousSplitInfo oblivious_split_info; }; // Checks that the leaf is not empty. @@ -343,7 +348,12 @@ class GrowTreeEnsembleOp : public OpKernel { OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t)); float learning_rate = learning_rate_t->scalar()(); - // Read seed that was used for dropout. + // Read the weak learner type to use. 
+ const Tensor* weak_learner_type_t; + OP_REQUIRES_OK(context, + context->input("weak_learner_type", &weak_learner_type_t)); + const int32 weak_learner_type = weak_learner_type_t->scalar()(); + const Tensor* seed_t; OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t)); // Cast seed to uint64. @@ -363,9 +373,18 @@ class GrowTreeEnsembleOp : public OpKernel { // Find best splits for each active partition. std::map best_splits; - FindBestSplitsPerPartition(context, partition_ids_list, gains_list, - splits_list, &best_splits); - + switch (weak_learner_type) { + case LearnerConfig::NORMAL_DECISION_TREE: { + FindBestSplitsPerPartitionNormal(context, partition_ids_list, + gains_list, splits_list, &best_splits); + break; + } + case LearnerConfig::OBLIVIOUS_DECISION_TREE: { + FindBestSplitsPerPartitionOblivious(context, gains_list, splits_list, + &best_splits); + break; + } + } // No-op if no new splits can be considered. if (best_splits.empty()) { LOG(WARNING) << "Not growing tree ensemble as no good splits were found."; @@ -377,25 +396,34 @@ class GrowTreeEnsembleOp : public OpKernel { OP_REQUIRES_OK(context, context->input("max_tree_depth", &max_tree_depth_t)); const int32 max_tree_depth = max_tree_depth_t->scalar()(); - // Update and retrieve the growable tree. // If the tree is fully built and dropout was applied, it also adjusts the // weights of dropped and the last tree. boosted_trees::trees::DecisionTreeConfig* const tree_config = UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate, - dropout_seed, max_tree_depth); - + dropout_seed, max_tree_depth, + weak_learner_type); // Split tree nodes. 
- for (auto& split_entry : best_splits) { - SplitTreeNode(split_entry.first, &split_entry.second, tree_config, - ensemble_resource); + switch (weak_learner_type) { + case LearnerConfig::NORMAL_DECISION_TREE: { + for (auto& split_entry : best_splits) { + SplitTreeNode(split_entry.first, &split_entry.second, tree_config, + ensemble_resource); + } + break; + } + case LearnerConfig::OBLIVIOUS_DECISION_TREE: { + SplitTreeLayer(&best_splits[0], tree_config, ensemble_resource); + } } - // Post-prune finalized tree if needed. if (learner_config_.pruning_mode() == boosted_trees::learner::LearnerConfig::POST_PRUNE && ensemble_resource->LastTreeMetadata()->is_finalized()) { VLOG(2) << "Post-pruning finalized tree."; + if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE) { + LOG(FATAL) << "Post-prunning is not implemented for Oblivious trees."; + } PruneTree(tree_config); // If after post-pruning the whole tree has no gain, remove the tree @@ -409,10 +437,9 @@ class GrowTreeEnsembleOp : public OpKernel { private: // Helper method which effectively does a reduce over all split candidates // and finds the best split for each partition. - void FindBestSplitsPerPartition( - OpKernelContext* const context, - const OpInputList& partition_ids_list, const OpInputList& gains_list, - const OpInputList& splits_list, + void FindBestSplitsPerPartitionNormal( + OpKernelContext* const context, const OpInputList& partition_ids_list, + const OpInputList& gains_list, const OpInputList& splits_list, std::map* best_splits) { // Find best split per partition going through every feature candidate. // TODO(salehay): Is this worth parallelizing? @@ -446,6 +473,90 @@ class GrowTreeEnsembleOp : public OpKernel { } } + void FindBestSplitsPerPartitionOblivious( + OpKernelContext* const context, const OpInputList& gains_list, + const OpInputList& splits_list, + std::map* best_splits) { + // Find best split per partition going through every feature candidate. 
+ for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { + const auto& gains = gains_list[handler_id].vec(); + const auto& splits = splits_list[handler_id].vec(); + OP_REQUIRES(context, gains.size() == 1, + errors::InvalidArgument( + "Gains size must be one for oblivious weak learner: ", + gains.size(), " != ", 1)); + OP_REQUIRES(context, splits.size() == 1, + errors::InvalidArgument( + "Splits size must be one for oblivious weak learner: ", + splits.size(), " != ", 1)); + // Get current split candidate. + const auto& gain = gains(0); + const auto& serialized_split = splits(0); + SplitCandidate split; + split.handler_id = handler_id; + split.gain = gain; + OP_REQUIRES( + context, split.oblivious_split_info.ParseFromString(serialized_split), + errors::InvalidArgument("Unable to parse oblivious split info.")); + + auto split_info = split.oblivious_split_info; + CHECK(split_info.children_size() % 2 == 0) + << "The oblivious split should generate an even number of children: " + << split_info.children_size(); + + // If every node is pure, then we shouldn't split. + bool only_pure_nodes = true; + for (int idx = 0; idx < split_info.children_size(); idx += 2) { + if (IsLeafWellFormed(*split_info.mutable_children(idx)) && + IsLeafWellFormed(*split_info.mutable_children(idx + 1))) { + only_pure_nodes = false; + break; + } + } + if (only_pure_nodes) { + VLOG(1) << "The oblivious split does not actually split anything."; + continue; + } + + // Don't consider negative splits if we're pre-pruning the tree. + if (learner_config_.pruning_mode() == learner::LearnerConfig::PRE_PRUNE && + gain < 0) { + continue; + } + + // Take the split if we don't have a candidate yet. + auto best_split_it = best_splits->find(0); + if (best_split_it == best_splits->end()) { + best_splits->insert(std::make_pair(0, std::move(split))); + continue; + } + + // Determine if we should update best split. 
+ SplitCandidate& best_split = best_split_it->second; + trees::TreeNode current_node = split_info.split_node(); + trees::TreeNode best_node = best_split.oblivious_split_info.split_node(); + if (TF_PREDICT_FALSE(gain == best_split.gain)) { + // Tie break on node case preferring simpler tree node types. + VLOG(2) << "Attempting to tie break with smaller node case. " + << "(current split: " << current_node.node_case() + << ", best split: " << best_node.node_case() << ")"; + if (current_node.node_case() < best_node.node_case()) { + best_split = std::move(split); + } else if (current_node.node_case() == best_node.node_case()) { + // Tie break on handler Id. + VLOG(2) << "Tie breaking with higher handler Id. " + << "(current split: " << handler_id + << ", best split: " << best_split.handler_id << ")"; + if (handler_id > best_split.handler_id) { + best_split = std::move(split); + } + } + } else if (gain > best_split.gain) { + best_split = std::move(split); + } + } + } + void UpdateTreeWeightsIfDropout( boosted_trees::models::DecisionTreeEnsembleResource* const ensemble_resource, @@ -501,7 +612,7 @@ class GrowTreeEnsembleOp : public OpKernel { boosted_trees::models::DecisionTreeEnsembleResource* const ensemble_resource, const float learning_rate, const uint64 dropout_seed, - const int32 max_tree_depth) { + const int32 max_tree_depth, const int32 weak_learner_type) { const auto num_trees = ensemble_resource->num_trees(); if (num_trees <= 0 || ensemble_resource->LastTreeMetadata()->is_finalized()) { @@ -647,6 +758,60 @@ class GrowTreeEnsembleOp : public OpKernel { } } + void SplitTreeLayer( + SplitCandidate* split, + boosted_trees::trees::DecisionTreeConfig* tree_config, + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) { + int depth = 0; + while (depth < tree_config->nodes_size() && + tree_config->nodes(depth).node_case() != TreeNode::kLeaf) { + depth++; + } + CHECK(tree_config->nodes_size() > 0) + << "A tree must have at least one dummy leaf."; + // 
The number of new children. + int num_children = 1 << (depth + 1); + auto split_info = split->oblivious_split_info; + CHECK(num_children == split_info.children_size()) + << "Wrong number of new children: " << num_children + << " != " << split_info.children_size(); + for (int idx = 0; idx < num_children; idx += 2) { + // Old leaf is at position depth + idx / 2. + trees::Leaf old_leaf = + *tree_config->mutable_nodes(depth + idx / 2)->mutable_leaf(); + // Update left leaf. + *split_info.mutable_children(idx) = + *MergeLeafWeights(old_leaf, split_info.mutable_children(idx)); + // Update right leaf. + *split_info.mutable_children(idx + 1) = + *MergeLeafWeights(old_leaf, split_info.mutable_children(idx + 1)); + } + TreeNodeMetadata* split_metadata = + split_info.mutable_split_node()->mutable_node_metadata(); + split_metadata->set_gain(split->gain); + + TreeNode new_split = *split_info.mutable_split_node(); + // Move old children to metadata. + for (int idx = depth; idx < tree_config->nodes_size(); idx++) { + *new_split.mutable_node_metadata()->add_original_oblivious_leaves() = + *tree_config->mutable_nodes(idx)->mutable_leaf(); + } + // Add the new split to the tree_config in place before the children start. + *tree_config->mutable_nodes(depth) = new_split; + // Add the new children + int nodes_size = tree_config->nodes_size(); + for (int idx = 0; idx < num_children; idx++) { + if (idx + depth + 1 < nodes_size) { + // Update leaves that were already there. + *tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() = + *split_info.mutable_children(idx); + } else { + // Add new leaves. + *tree_config->add_nodes()->mutable_leaf() = + *split_info.mutable_children(idx); + } + } + } void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) { // No-op if tree is empty. 
if (tree_config->nodes_size() <= 0) { diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 6572f2f414b..d9caebb645a 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -258,8 +258,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): oblivious_split_info = split_info_pb2.ObliviousSplitInfo() oblivious_split_info.ParseFromString(splits[0]) - split_node = oblivious_split_info.split_node.dense_float_binary_split - + split_node = oblivious_split_info.split_node + split_node = split_node.oblivious_dense_float_binary_split self.assertAllClose(0.3, split_node.threshold, 0.00001) self.assertEqual(0, split_node.feature_column) @@ -279,8 +279,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) expected_bias_gain_0 = 0.46043165467625896 - left_child = oblivious_split_info.children_leaves[0].vector - right_child = oblivious_split_info.children_leaves[1].vector + left_child = oblivious_split_info.children[0].vector + right_child = oblivious_split_info.children[1].vector self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001) @@ -296,8 +296,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): # (-4 + 0.1) ** 2 / (0.13 + 1) expected_bias_gain_1 = 13.460176991150442 - left_child = oblivious_split_info.children_leaves[2].vector - right_child = oblivious_split_info.children_leaves[3].vector + left_child = oblivious_split_info.children[2].vector + right_child = oblivious_split_info.children[3].vector self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001) diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc index 
0e5578693a7..3ed6c5c04d6 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= +#include + #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" #include "tensorflow/core/platform/macros.h" -#include - namespace tensorflow { namespace boosted_trees { namespace trees { @@ -28,14 +28,15 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) { return kInvalidLeaf; } - // Traverse tree starting at the provided sub-root. int32 node_id = sub_root_id; + // The index of the leaf that holds this example in the oblivious case. + int oblivious_leaf_idx = 0; while (true) { const auto& current_node = config.nodes(node_id); switch (current_node.node_case()) { case TreeNode::kLeaf: { - return node_id; + return node_id + oblivious_leaf_idx; } case TreeNode::kDenseFloatBinarySplit: { const auto& split = current_node.dense_float_binary_split(); @@ -100,6 +101,16 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, } break; } + case TreeNode::kObliviousDenseFloatBinarySplit: { + const auto& split = current_node.oblivious_dense_float_binary_split(); + oblivious_leaf_idx <<= 1; + if (example.dense_float_features[split.feature_column()] > + split.threshold()) { + oblivious_leaf_idx++; + } + node_id++; + break; + } case TreeNode::NODE_NOT_SET: { LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString(); break; @@ -165,6 +176,11 @@ void DecisionTree::LinkChildren(const std::vector& children, split->set_right_id(*++children_it); break; } + case TreeNode::kObliviousDenseFloatBinarySplit: { + LOG(QFATAL) + << "Not implemented for the ObliviousDenseFloatBinarySplit case."; + break; + } case 
TreeNode::NODE_NOT_SET: { LOG(QFATAL) << "A non-set node cannot have children."; break; @@ -199,6 +215,11 @@ std::vector DecisionTree::GetChildren(const TreeNode& node) { const auto& split = node.categorical_id_set_membership_binary_split(); return {split.left_id(), split.right_id()}; } + case TreeNode::kObliviousDenseFloatBinarySplit: { + LOG(QFATAL) + << "Not implemented for the ObliviousDenseFloatBinarySplit case."; + return {}; + } case TreeNode::NODE_NOT_SET: { return {}; } diff --git a/tensorflow/contrib/boosted_trees/ops/training_ops.cc b/tensorflow/contrib/boosted_trees/ops/training_ops.cc index 22ac9edb72e..604ec8e0bfa 100644 --- a/tensorflow/contrib/boosted_trees/ops/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/training_ops.cc @@ -57,6 +57,7 @@ REGISTER_OP("GrowTreeEnsemble") .Input("learning_rate: float") .Input("dropout_seed: int64") .Input("max_tree_depth: int32") + .Input("weak_learner_type: int32") .Input("partition_ids: num_handlers * int32") .Input("gains: num_handlers * float") .Input("splits: num_handlers * string") @@ -82,6 +83,7 @@ tree_ensemble_handle: Handle to the ensemble variable. stamp_token: Stamp token for validating operation consistency. next_stamp_token: Stamp token to be used for the next iteration. learning_rate: Scalar learning rate. +weak_learner_type: The type of weak learner to use. partition_ids: List of Rank 1 Tensors containing partition Id per candidate. gains: List of Rank 1 Tensors containing gains per candidate. splits: List of Rank 1 Tensors containing serialized SplitInfo protos per candidate. 
diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto index 850340f5c20..65448996bff 100644 --- a/tensorflow/contrib/boosted_trees/proto/split_info.proto +++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto @@ -19,8 +19,6 @@ message SplitInfo { } message ObliviousSplitInfo { - // The split node with the feature_column and threshold defined. tensorflow.boosted_trees.trees.TreeNode split_node = 1; - // The new leaves of the tree. - repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2; + repeated tensorflow.boosted_trees.trees.Leaf children = 2; } diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index 81411aa84ae..500909bf2a1 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -15,6 +15,7 @@ message TreeNode { CategoricalIdBinarySplit categorical_id_binary_split = 5; CategoricalIdSetMembershipBinarySplit categorical_id_set_membership_binary_split = 6; + ObliviousDenseFloatBinarySplit oblivious_dense_float_binary_split = 7; } TreeNodeMetadata node_metadata = 777; } @@ -26,6 +27,9 @@ message TreeNodeMetadata { // The original leaf node before this node was split. Leaf original_leaf = 2; + + // The original layer of leaves before that layer was converted to a split. + repeated Leaf original_oblivious_leaves = 3; } // Leaves can either hold dense or sparse information. @@ -101,6 +105,17 @@ message CategoricalIdSetMembershipBinarySplit { int32 right_id = 4; } +// Split rule for dense float features in the oblivious case. +message ObliviousDenseFloatBinarySplit { + // Float feature column and split threshold describing + // the rule feature <= threshold. 
+ int32 feature_column = 1; + float threshold = 2; + // We don't store children ids, because either the next node represents the + // whole next layer of the tree or starting with the next node we only have + // leaves. +} + // DecisionTreeConfig describes a list of connected nodes. // Node 0 must be the root and can carry any payload including a leaf // in the case of representing the bias. diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index e39e1de8d19..572717e216b 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -91,6 +91,27 @@ def _gen_dense_split_info(fc, threshold, left_weight, right_weight): return split.SerializeToString() +def _gen_dense_oblivious_split_info(fc, threshold, leave_weights): + split_str = """ + split_node { + oblivious_dense_float_binary_split { + feature_column: %d + threshold: %f + } + }""" % (fc, threshold) + for weight in leave_weights: + split_str += """ + children { + vector { + value: %f + } + }""" % ( + weight) + split = split_info_pb2.ObliviousSplitInfo() + text_format.Merge(split_str, split) + return split.SerializeToString() + + def _gen_categorical_split_info(fc, feat_id, left_weight, right_weight): split_str = """ split_node { @@ -324,7 +345,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect the simpler split from handler 1 to be chosen. 
@@ -383,6 +405,115 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual(stats.attempted_layers, 1) self.assertProtoEquals(expected_result, tree_ensemble_config) + def testGrowEmptyEnsembleObliviousCase(self): + """Test growing an empty ensemble in the oblivious case.""" + with self.test_session() as session: + # Create empty ensemble. + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=1, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + + # Prepare handler inputs. + # Note that handlers 1 & 3 have the same gain but different splits. + handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([7.62], dtype=np.float32) + handler1_split = [ + _gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143]) + ] + handler2_partitions = np.array([0], dtype=np.int32) + handler2_gains = np.array([0.63], dtype=np.float32) + handler2_split = [_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24])] + handler3_partitions = np.array([0], dtype=np.int32) + handler3_gains = np.array([7.62], dtype=np.float32) + handler3_split = [_gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143])] + + # Grow tree ensemble. 
+ grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=0.1, + partition_ids=[ + handler1_partitions, handler2_partitions, handler3_partitions + ], + gains=[handler1_gains, handler2_gains, handler3_gains], + splits=[handler1_split, handler2_split, handler3_split], + learner_config=learner_config.SerializeToString(), + dropout_seed=123, + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + session.run(grow_op) + + # Expect the split with bigger handler_id, i.e. handler 3 to be chosen. + # The grown tree should be finalized as max tree depth is 1. + new_stamp, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + stats = session.run( + training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1)) + tree_ensemble_config.ParseFromString(serialized) + expected_result = """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 0 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + leaf { + vector { + value: -4.375 + } + } + } + nodes { + leaf { + vector { + value: 7.143 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + } + """ + self.assertEqual(new_stamp, 1) + self.assertEqual(stats.num_trees, 1) + self.assertEqual(stats.num_layers, 1) + self.assertEqual(stats.active_tree, 1) + self.assertEqual(stats.active_layer, 1) + self.assertEqual(stats.attempted_trees, 1) + self.assertEqual(stats.attempted_layers, 1) + self.assertProtoEquals(expected_result, tree_ensemble_config) + def testGrowExistingEnsembleTreeNotFinalized(self): """Test growing an existing ensemble with the last tree not finalized.""" with self.test_session() as session: @@ -476,7 
+607,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect the split for partition 1 to be chosen from handler 1 and @@ -661,7 +793,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect a new tree to be added with the split from handler 1. @@ -798,7 +931,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect the ensemble to be empty. @@ -869,7 +1003,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect the simpler split from handler 1 to be chosen. 
@@ -971,7 +1106,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect the split from handler 2 to be chosen despite the negative gain. @@ -1053,7 +1189,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect the ensemble to be empty as post-pruning will prune @@ -1120,7 +1257,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect the split from handler 2 to be chosen despite the negative gain. 
@@ -1200,7 +1338,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect the negative gain split of partition 1 to be pruned and the @@ -1371,7 +1510,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect the split for partition 1 to be chosen from handler 1 and @@ -1470,6 +1610,193 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual(stats.attempted_layers, 2) self.assertProtoEquals(expected_result, tree_ensemble_config) + def testGrowEnsembleTreeLayerByLayerObliviousCase(self): + """Test growing an unfinalized oblivious tree by one more layer.""" + with self.test_session() as session: + # Create an existing ensemble with one oblivious root split. + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 4 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + leaf { + vector { + value: 7.143 + } + } + } + nodes { + leaf { + vector { + value: -4.375 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 1 + } + """, tree_ensemble_config) + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, +
tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=3, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER) + + # Prepare handler inputs. + handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([1.4], dtype=np.float32) + handler1_split = [ + _gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5]) + ] + handler2_partitions = np.array([0], dtype=np.int32) + handler2_gains = np.array([2.7], dtype=np.float32) + handler2_split = [ + _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4]), + ] + handler3_partitions = np.array([0], dtype=np.int32) + handler3_gains = np.array([1.7], dtype=np.float32) + handler3_split = [ + _gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1]) + ] + + # Grow tree ensemble layer by layer. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=0.1, + partition_ids=[ + handler1_partitions, handler2_partitions, handler3_partitions + ], + gains=[handler1_gains, handler2_gains, handler3_gains], + splits=[handler1_split, handler2_split, handler3_split], + learner_config=learner_config.SerializeToString(), + dropout_seed=123, + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + session.run(grow_op) + + # Expect the split from handler 2, the candidate with the highest gain + # (2.7), to be chosen and applied to both leaves of the existing layer. + # The grown tree should not be finalized as max tree depth is 3 and + # it's only grown 2 layers. + # The first two child weights (-0.6, 0.24) get added to original leaf 7.143.
+ # The last two child weights (0.3, 0.4) get added to original leaf -4.375. + new_stamp, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + stats = session.run( + training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1)) + tree_ensemble_config.ParseFromString(serialized) + expected_result = """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 4 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 0 + threshold: 0.23 + } + node_metadata { + gain: 2.7 + original_oblivious_leaves { + vector { + value: 7.143 + } + } + original_oblivious_leaves { + vector { + value: -4.375 + } + } + } + } + nodes { + leaf { + vector { + value: 6.543 + } + } + } + nodes { + leaf { + vector { + value: 7.383 + } + } + } + nodes { + leaf { + vector { + value: -4.075 + } + } + } + nodes { + leaf { + vector { + value: -3.975 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 2 + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + } + """ + self.assertEqual(new_stamp, 1) + self.assertEqual(stats.num_trees, 0) + self.assertEqual(stats.num_layers, 2) + self.assertEqual(stats.active_tree, 1) + self.assertEqual(stats.active_layer, 2) + self.assertEqual(stats.attempted_trees, 1) + self.assertEqual(stats.attempted_layers, 2) + self.assertProtoEquals(expected_result, tree_ensemble_config) + def testGrowExistingEnsembleTreeFinalizedWithDropout(self): """Test growing an existing ensemble with the last tree finalized.""" with self.test_session() as session: @@ -1575,7 +1902,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, +
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) # Expect a new tree to be added with the split from handler 1. @@ -1700,7 +2028,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): learner_config=learner_config.SerializeToString(), dropout_seed=123, center_bias=True, - max_tree_depth=learner_config.constraints.max_tree_depth) + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE) session.run(grow_op) _, serialized = session.run( diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 2f75d8aa99c..97743ba255a 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -1076,7 +1076,8 @@ class GradientBoostedDecisionTreeModel(object): learner_config=self._learner_config_serialized, dropout_seed=dropout_seed, center_bias=self._center_bias, - max_tree_depth=self._max_tree_depth) + max_tree_depth=self._max_tree_depth, + weak_learner_type=self._learner_config.weak_learner_type) def _grow_ensemble_not_ready_fn(): # Don't grow the ensemble, just update the stamp. @@ -1091,7 +1092,8 @@ class GradientBoostedDecisionTreeModel(object): learner_config=self._learner_config_serialized, dropout_seed=dropout_seed, center_bias=self._center_bias, - max_tree_depth=self._max_tree_depth) + max_tree_depth=self._max_tree_depth, + weak_learner_type=self._learner_config.weak_learner_type) def _grow_ensemble_fn(): # Conditionally grow an ensemble depending on whether the splits