diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesSparseCalculateBestFeatureSplit.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesSparseCalculateBestFeatureSplit.pbtxt new file mode 100644 index 00000000000..ff39bbe5143 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesSparseCalculateBestFeatureSplit.pbtxt @@ -0,0 +1,118 @@ +op { + graph_op_name: "BoostedTreesSparseCalculateBestFeatureSplit" + visibility: HIDDEN + in_arg { + name: "node_id_range" + description: < #include #include "third_party/eigen3/Eigen/Core" @@ -25,6 +26,7 @@ limitations under the License. namespace tensorflow { const char INEQUALITY_DEFAULT_LEFT[] = "inequality_default_left"; +const char INEQUALITY_DEFAULT_RIGHT[] = "inequality_default_right"; // V1 Op. Deprecated. BoostedTreesCalculateBestFeatureSplitOp is V2. class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel { @@ -439,6 +441,306 @@ REGISTER_KERNEL_BUILDER( Name("BoostedTreesCalculateBestFeatureSplit").Device(DEVICE_CPU), BoostedTreesCalculateBestFeatureSplitOp); +// Map from bucket id to vector of statistics. +typedef std::map> BucketMap; +typedef BucketMap::iterator BucketMapIterator; +// Map from feature dimension to BucketMap. +typedef std::map FeatureMap; +typedef FeatureMap::iterator FeatureMapIterator; + +class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel { + public: + explicit BoostedTreesSparseCalculateBestFeatureSplitOp( + OpKernelConstruction* const context) + : OpKernel(context) { + // TODO(crawles): Using logits_dim_ for multi-class split. + OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_)); + // TODO(tanzheny): Using this for equality split. + OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_)); + } + + void Compute(OpKernelContext* const context) override { + // node_id_range + const Tensor* node_id_range_t; + OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t)); + const auto node_id_range = node_id_range_t->vec(); + const int32 node_id_first = node_id_range(0); // inclusive + const int32 node_id_last = node_id_range(1); // exclusive + + const Tensor* stats_summary_indices_t; + OP_REQUIRES_OK(context, context->input("stats_summary_indices", + &stats_summary_indices_t)); + const auto stats_summary_indices = stats_summary_indices_t->matrix(); + const int32 num_sparse_entries = stats_summary_indices_t->dim_size(0); + + const Tensor* stats_summary_values_t; + OP_REQUIRES_OK(context, context->input("stats_summary_values", + &stats_summary_values_t)); + const auto stats_summary_values = stats_summary_values_t->vec(); + + const Tensor* stats_summary_shape_t; + OP_REQUIRES_OK( + context, context->input("stats_summary_shape", &stats_summary_shape_t)); + const auto stats_summary_shape = stats_summary_shape_t->vec(); + const int32 num_buckets = stats_summary_shape(2) - 1; + const int32 stats_dims = stats_summary_shape(3); + + const Tensor* l1_t; + OP_REQUIRES_OK(context, context->input("l1", &l1_t)); + const auto l1 = l1_t->scalar()(); + + const Tensor* l2_t; + OP_REQUIRES_OK(context, context->input("l2", &l2_t)); + const auto l2 = l2_t->scalar()(); + + const Tensor* tree_complexity_t; + OP_REQUIRES_OK(context, + context->input("tree_complexity", &tree_complexity_t)); + const auto tree_complexity = tree_complexity_t->scalar()(); + + const Tensor* min_node_weight_t; + OP_REQUIRES_OK(context, + context->input("min_node_weight", &min_node_weight_t)); + const auto min_node_weight = min_node_weight_t->scalar()(); + + std::vector output_node_ids; + std::vector output_gains; + std::vector output_feature_dimensions; + std::vector output_thresholds; + std::vector output_left_node_contribs; + std::vector output_right_node_contribs; + std::vector output_split_types; + + FeatureMap f_map; + + int32 previous_node_id = -1; + for (int idx = 0; idx < num_sparse_entries; ++idx) { + int32 node_id = stats_summary_indices(idx, 0); + if (node_id != previous_node_id) { + process_node(f_map, &output_node_ids, &output_gains, + &output_feature_dimensions, &output_thresholds, + &output_left_node_contribs, &output_right_node_contribs, + &output_split_types, previous_node_id, min_node_weight, l1, + l2, num_buckets); + f_map.clear(); + } + previous_node_id = node_id; + DCHECK_LE(node_id_first, node_id); + DCHECK_LT(node_id, node_id_last); + const int32 feature_dim = stats_summary_indices(idx, 1); + const int32 bucket_id = stats_summary_indices(idx, 2); + const int32 stat_dim = stats_summary_indices(idx, 3); + std::pair const& f_insert_result = f_map.insert( + FeatureMapIterator::value_type(feature_dim, BucketMap())); + auto& b_map = f_insert_result.first->second; + std::pair const& b_insert_result = + b_map.insert(BucketMapIterator::value_type( + bucket_id, std::vector(stats_dims))); + auto& stats = b_insert_result.first->second; + stats[stat_dim] = stats_summary_values(idx); + } // for node_id + // process the last node id + process_node(f_map, &output_node_ids, &output_gains, + &output_feature_dimensions, &output_thresholds, + &output_left_node_contribs, &output_right_node_contribs, + &output_split_types, previous_node_id, min_node_weight, l1, l2, + num_buckets); + + const int num_nodes = output_node_ids.size(); + // output_node_ids + Tensor* output_node_ids_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes}, + &output_node_ids_t)); + auto output_node_ids_vec = output_node_ids_t->vec(); + + // output_gains + Tensor* output_gains_t; + OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes}, + &output_gains_t)); + auto output_gains_vec = output_gains_t->vec(); + + // output_feature_dimensions + Tensor* output_feature_dimension_t; + OP_REQUIRES_OK(context, + context->allocate_output("feature_dimensions", {num_nodes}, + &output_feature_dimension_t)); + auto output_feature_dimensions_vec = + output_feature_dimension_t->vec(); + + // output_thresholds + Tensor* output_thresholds_t; + OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes}, + &output_thresholds_t)); + auto output_thresholds_vec = output_thresholds_t->vec(); + + // output_left_node_contribs + Tensor* output_left_node_contribs_t; + OP_REQUIRES_OK( + context, context->allocate_output("left_node_contribs", {num_nodes, 1}, + &output_left_node_contribs_t)); + auto output_left_node_contribs_matrix = + output_left_node_contribs_t->matrix(); + + // output_right_node_contribs + Tensor* output_right_node_contribs_t; + OP_REQUIRES_OK( + context, context->allocate_output("right_node_contribs", {num_nodes, 1}, + &output_right_node_contribs_t)); + auto output_right_node_contribs_matrix = + output_right_node_contribs_t->matrix(); + + // split type + Tensor* output_split_types_t; + OP_REQUIRES_OK( + context, context->allocate_output("split_with_default_directions", + {num_nodes}, &output_split_types_t)); + auto output_split_types_vec = output_split_types_t->vec(); + + // Sets output tensors from vectors. + for (int i = 0; i < num_nodes; ++i) { + output_node_ids_vec(i) = output_node_ids[i]; + // Adjust the gains to penalize by tree complexity. + output_gains_vec(i) = output_gains[i] - tree_complexity; + output_feature_dimensions_vec(i) = output_feature_dimensions[i]; + output_thresholds_vec(i) = output_thresholds[i]; + // TODO(crawles): change this for multi-class. + output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i]; + output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i]; + output_split_types_vec(i) = output_split_types[i]; + } + } + + protected: + void process_node(const FeatureMap& f_map, + std::vector* output_node_ids, + std::vector* output_gains, + std::vector* output_feature_dimensions, + std::vector* output_thresholds, + std::vector* output_left_node_contribs, + std::vector* output_right_node_contribs, + std::vector* output_split_types, + const int32 node_id, const float min_node_weight, + const float l1, const float l2, const int32 num_buckets) { + float parent_gain; + Eigen::VectorXf unused(logits_dim_); + Eigen::MatrixXf identity; + identity.setIdentity(1, 1); + + // start processing for previous node id. + float best_gain = std::numeric_limits::lowest(); + float best_bucket = 0; + float best_f_dim = 0; + string best_split_type = INEQUALITY_DEFAULT_LEFT; + float best_contrib_for_left = 0.0; + float best_contrib_for_right = 0.0; + // the sum of gradients including default bucket. + float total_grad = 0; + // the sum of hessians including default bucket. + float total_hess = 0; + + for (auto f_iter = f_map.begin(); f_iter != f_map.end(); ++f_iter) { + const int32 feature_dim = f_iter->first; + const auto buckets_to_stats_map = f_iter->second; + + // The very last bucket contains stats for missing values. + // TODO(crawles): use vector for multi-class. + const float default_grad = + (buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end() + ? 0 + : buckets_to_stats_map.at(num_buckets)[0]); + const float default_hess = + (buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end() + ? 0 + : buckets_to_stats_map.at(num_buckets)[1]); + + if (f_iter == f_map.begin()) { + // first get the sum of grads, including default bucket. + for (auto b_iter = buckets_to_stats_map.begin(); + b_iter != buckets_to_stats_map.end(); ++b_iter) { + total_grad += b_iter->second[0]; + total_hess += b_iter->second[1]; + } + if (total_hess < min_node_weight) { + // Do not split the node because not enough avg hessian. + break; + } + CalculateWeightsAndGains(total_grad * identity, total_hess * identity, + l1, l2, &unused, &parent_gain); + } + + float total_left_grad = 0; + float total_left_hess = 0; + for (auto b_iter = buckets_to_stats_map.begin(); + b_iter != buckets_to_stats_map.end(); ++b_iter) { + const int32 bucket_id = b_iter->first; + // total_left_stats should exclude stats from default bucket. + if (bucket_id == num_buckets) { + break; + } + // TODO(crawles): vector for multi-class. + total_left_grad += b_iter->second[0]; + total_left_hess += b_iter->second[1]; + // From left to right, default right. + // Left child. + Eigen::VectorXf contrib_for_left(1); + float gain_for_left; + CalculateWeightsAndGains(total_left_grad * identity, + total_left_hess * identity, l1, l2, + &contrib_for_left, &gain_for_left); + // Right child. + Eigen::VectorXf contrib_for_right(1); + float gain_for_right; + CalculateWeightsAndGains((total_grad - total_left_grad) * identity, + (total_hess - total_left_hess) * identity, l1, + l2, &contrib_for_right, &gain_for_right); + if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) { + best_gain = gain_for_left + gain_for_right; + best_bucket = bucket_id; + best_f_dim = feature_dim; + best_split_type = INEQUALITY_DEFAULT_RIGHT; + best_contrib_for_left = contrib_for_left[0]; + best_contrib_for_right = contrib_for_right[0]; + } + + // From right to left, default left. + CalculateWeightsAndGains((total_left_grad + default_grad) * identity, + (total_left_hess + default_hess) * identity, + l1, l2, &contrib_for_left, &gain_for_left); + CalculateWeightsAndGains( + (total_grad - default_grad - total_left_grad) * identity, + (total_hess - default_hess - total_left_hess) * identity, l1, l2, + &contrib_for_right, &gain_for_right); + if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) { + best_gain = gain_for_left + gain_for_right; + best_bucket = bucket_id; + best_f_dim = feature_dim; + best_split_type = INEQUALITY_DEFAULT_LEFT; + best_contrib_for_left = contrib_for_left[0]; + best_contrib_for_right = contrib_for_right[0]; + } + } // for bucket_id + } // for feature_dim + if (best_gain != std::numeric_limits::lowest()) { + output_node_ids->push_back(node_id); + // Remove the parent gain. + output_gains->push_back(best_gain - parent_gain); + output_feature_dimensions->push_back(best_f_dim); + output_split_types->push_back(best_split_type); + output_thresholds->push_back(best_bucket); + output_left_node_contribs->push_back(best_contrib_for_left); + output_right_node_contribs->push_back(best_contrib_for_right); + } + } + + private: + int logits_dim_; + string split_type_; +}; + +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesSparseCalculateBestFeatureSplit").Device(DEVICE_CPU), + BoostedTreesSparseCalculateBestFeatureSplitOp); + class BoostedTreesMakeStatsSummaryOp : public OpKernel { public: explicit BoostedTreesMakeStatsSummaryOp(OpKernelConstruction* const context) diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc index 8831e703321..b05b2f57898 100644 --- a/tensorflow/core/ops/boosted_trees_ops.cc +++ b/tensorflow/core/ops/boosted_trees_ops.cc @@ -132,6 +132,48 @@ REGISTER_OP("BoostedTreesCalculateBestFeatureSplit") return Status::OK(); }); +REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit") + .Input("node_id_range: int32") + .Input("stats_summary_indices: int32") + .Input("stats_summary_values: float") + .Input("stats_summary_shape: int32") + .Input("l1: float") + .Input("l2: float") + .Input("tree_complexity: float") + .Input("min_node_weight: float") + .Attr("logits_dimension: int >= 1") + .Attr("split_type: {'inequality'} = 'inequality'") + .Output("node_ids: int32") + .Output("gains: float32") + .Output("feature_dimensions: int32") + .Output("thresholds: int32") + .Output("left_node_contribs: float32") + .Output("right_node_contribs: float32") + .Output("split_with_default_directions: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle node_id_range_shape; + shape_inference::ShapeHandle unused_shape; + // node id range is rank 1 with 2 values. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape)); + TF_RETURN_IF_ERROR( + c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused_shape)); + shape_inference::ShapeHandle summary_shape; + TF_RETURN_IF_ERROR( + c->Merge(summary_shape, c->MakeShape({4}), &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape)); + ShapeHandle output_shape = c->MakeShape({-1}); + for (int i = 0; i < 7; ++i) { + c->set_output(i, output_shape); + } + return Status::OK(); + }); + REGISTER_OP("BoostedTreesCreateEnsemble") .Input("tree_ensemble_handle: resource") .Input("stamp_token: int64") diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py index 92e920a68da..32e47efb64e 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py @@ -29,6 +29,7 @@ from tensorflow.python.platform import googletest _INEQUALITY_DEFAULT_LEFT = 'inequality_default_left'.encode('utf-8') +_INEQUALITY_DEFAULT_RIGHT = 'inequality_default_right'.encode('utf-8') class StatsOpsTest(test_util.TensorFlowTestCase): @@ -56,6 +57,98 @@ class StatsOpsTest(test_util.TensorFlowTestCase): ], # feature 1 ] # shape=[num_features, max_splits, num_buckets, 2] + def _get_sparse_stats_summary_for_split(self, stats_summary=None): + if stats_summary is None: + stats_summary = np.asarray(self._get_stats_summary_for_split()) + stats_summary[0][0][1] = np.zeros([2]) + stats_summary[1][0][2] = np.zeros([2]) + stats_summary = np.moveaxis(stats_summary, 0, 1) + slices = stats_summary.nonzero() + values = stats_summary[slices] + indices = np.asarray(slices) + return np.moveaxis(indices, 0, 1), values, stats_summary.shape + + def testCalculateBestSplitsWithoutRegularizationInSparse(self): + # This test uses the same data as dense, but run in sparse kernel and + # make sure the sparse kernel returns same result as dense kernel. + dense_summary = np.asarray([ + [ + [[0., 0.], [.0, .0], [0., 0.], [0., 0.]], # node 0; ignored + [[0., 0.], [.15, .36], [.06, .07], [.1, .2]], # node 1 + [[0., 0.], [-.33, .58], [0., 0.], [.3, .4]], # node 2 + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored + ], # feature 0 + [ + [[0., 0.], [0., 0.], [.0, .0], [0., 0.]], # node 0; ignored + [[0., 0.], [.3, .5], [-.05, .06], [.06, .07]], # node 1 + [[.1, .1], [.2, .3], [-.4, .5], [.07, .08]], # node 2 + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored + ], # feature 1 + ]) # num_features * shape=[max_splits, num_buckets, 2] + node_id_range = [1, 3] + dense_summary = np.moveaxis(dense_summary, 0, 1) + dense_shape = dense_summary.shape + + default_bucket_summary = np.zeros(dense_shape[0:2] + (1, dense_shape[3])) + sparse_summary = np.concatenate((dense_summary, default_bucket_summary), + axis=2) + slices = sparse_summary.nonzero() + summary_values = sparse_summary[slices] + summary_indices = np.asarray(slices) + summary_indices = np.moveaxis(summary_indices, 0, 1) + summary_shape = sparse_summary.shape + + (node_ids, gains, _, _, left_node_contribs, right_node_contribs, + _) = self.evaluate( + boosted_trees_ops.sparse_calculate_best_feature_split( + node_id_range, + summary_indices, + summary_values, + summary_shape, + l1=0.0, + l2=0.0, + tree_complexity=0.0, + min_node_weight=0, + logits_dimension=1)) + + self.assertAllEqual([1, 2], node_ids) + self.assertAllClose([0.02823, 0.41184], gains) + self.assertAllClose([-0.6], left_node_contribs[0]) + self.assertAllClose([-0.076923], right_node_contribs[0]) + + def testSparseCalculateBestSplitsWithoutRegularization(self): + node_id_range = [1, 3] + (summary_indices, summary_values, + summary_shape) = self._get_sparse_stats_summary_for_split() + + (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, + right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.sparse_calculate_best_feature_split( + node_id_range, + summary_indices, + summary_values, + summary_shape, + l1=0.0, + l2=0.0, + tree_complexity=0.0, + min_node_weight=0, + logits_dimension=1)) + self.assertAllEqual([1, 2], node_ids) + self.assertAllClose([0.116495, 0.60429], gains) + self.assertAllEqual([1, 1], thresholds) + self.assertAllEqual([1, 1], feature_dimensions) + # The left node contrib will be later added to the previous node value to + # make the left node value, and the same for right node contrib. + self.assertAllClose([[-0.631579], [-0.770833]], left_node_contribs) + self.assertAllClose([[0.833333], [0.8]], right_node_contribs) + self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) + def testCalculateBestGainsWithoutRegularization(self): """Testing Gain calculation without any regularization.""" with self.cached_session() as sess: @@ -174,6 +267,34 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[-.043478], [-.6]], right_node_contribs) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) + def testSparseCalculateBestSplitsWithL2(self): + node_id_range = [1, 3] + (summary_indices, summary_values, + summary_shape) = self._get_sparse_stats_summary_for_split() + + (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, + right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.sparse_calculate_best_feature_split( + node_id_range, + summary_indices, + summary_values, + summary_shape, + l1=0.0, + l2=0.1, + tree_complexity=0.0, + min_node_weight=0, + logits_dimension=1)) + self.assertAllEqual([1, 2], node_ids) + self.assertAllClose([0.077414, 0.501868], gains) + self.assertAllEqual([1, 1], feature_dimensions) + self.assertAllEqual([1, 1], thresholds) + # The left node contrib will be later added to the previous node value to + # make the left node value, and the same for right node contrib. + self.assertAllClose([[-0.537313], [-0.637931]], left_node_contribs) + self.assertAllClose([[0.3125], [0.666667]], right_node_contribs) + self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT, _INEQUALITY_DEFAULT_LEFT], + split_types) + def testCalculateBestGainsWithL1(self): """Testing Gain calculation with L1.""" with self.cached_session() as sess: @@ -236,6 +357,33 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([1, 1], feature_dimensions) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) + def testSparseCalculateBestSplitsWithL1(self): + node_id_range = [1, 3] + (summary_indices, summary_values, + summary_shape) = self._get_sparse_stats_summary_for_split() + + (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, + right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.sparse_calculate_best_feature_split( + node_id_range, + summary_indices, + summary_values, + summary_shape, + l1=0.1, + l2=0., + tree_complexity=0.0, + min_node_weight=0, + logits_dimension=1)) + self.assertAllEqual([1, 2], node_ids) + self.assertAllClose([0.048597, 0.331875], gains) + self.assertAllEqual([1, 1], feature_dimensions) + self.assertAllEqual([1, 1], thresholds) + # The left node contrib will be later added to the previous node value to + # make the left node value, and the same for right node contrib. + self.assertAllClose([[-0.45614], [-0.5625]], left_node_contribs) + self.assertAllClose([[0.0], [0.6]], right_node_contribs) + self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) + def testCalculateBestGainsWithTreeComplexity(self): """Testing best gain calculation with tree complexity.""" with self.cached_session() as sess: @@ -300,6 +448,33 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([1, 0], feature_dimensions) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) + def testSparseCalculateBestSplitsWithTreeComplexity(self): + """Testing best split calculation with tree complexity.""" + node_id_range = [1, 3] + (summary_indices, summary_values, + summary_shape) = self._get_sparse_stats_summary_for_split() + + (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, + right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.sparse_calculate_best_feature_split( + node_id_range, + summary_indices, + summary_values, + summary_shape, + l1=0., + l2=0.1, + tree_complexity=3., + min_node_weight=0, + logits_dimension=1)) + + self.assertAllEqual([1, 2], node_ids) + self.assertAllClose([-2.922586, -2.498132], gains) + self.assertAllEqual([1, 1], feature_dimensions) + self.assertAllEqual([1, 1], thresholds) + self.assertAllClose([[-0.537313], [-0.637931]], left_node_contribs) + self.assertAllClose([[0.3125], [0.666667]], right_node_contribs) + self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) + def testCalculateBestGainsWithMinNodeWeight(self): """Testing Gain calculation with min node weight.""" with self.cached_session() as sess: @@ -393,8 +568,59 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([1, 1], feature_dimensions) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeaturePossible(self): - """Testing Gain calculation with min node weight and no split.""" + def testSparseCalculateBestSplitsWithMinNodeWeight(self): + """Testing best split calculation with min node weight.""" + node_id_range = [1, 3] # node 1 through 2 will be processed. + stats_summary = np.asarray([ + [ + [[0., 0.], [.0, .0], [0., 0.], [0., 0.]], # node 0; ignored + [[0., 0.], [.15, .36], [.06, .61], [.1, .2]], # node 1 + [[0., 0.], [-.33, .68], [0., 0.], [.3, .4]], # node 2 + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored + ], # feature 0 + [ + [[0., 0.], [0., 0.], [.0, .0], [0., 0.]], # node 0; ignored + [[0., 0.], [-.05, .6], [.3, .5], [.06, .07]], # node 1 + [[.1, 1.], [.2, -.05], [-.4, .05], [.07, .08]], # node 2 + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored + ], # feature 1 + ]) # num_features * shape=[max_splits, num_buckets, 2] + # reshape to [max_splits, num_features, num_buckets, 2] + stats_summary = np.moveaxis(stats_summary, 0, 1) + + (summary_indices, summary_values, + summary_shape) = self._get_sparse_stats_summary_for_split(stats_summary) + + (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, + right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.sparse_calculate_best_feature_split( + node_id_range, + summary_indices, + summary_values, + summary_shape, + l1=0., + l2=0., + tree_complexity=0., + min_node_weight=1, + logits_dimension=1)) + + self.assertAllEqual([1, 2], node_ids) + self.assertAllClose([0.149398, 3.332079], gains) + self.assertAllEqual([1, 1], thresholds) + self.assertAllClose([[0.083333], [-0.359223]], left_node_contribs) + self.assertAllClose([[-0.631579], [7.999998]], right_node_contribs) + self.assertAllEqual([1, 1], feature_dimensions) + self.assertAllEqual([_INEQUALITY_DEFAULT_RIGHT, _INEQUALITY_DEFAULT_LEFT], + split_types) + + def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeturePossible(self): + """Testing Gain calculation without any regularization.""" with self.cached_session() as sess: max_splits = 7 node_id_range = [1, 3] # node 1 through 2 will be processed. @@ -497,6 +723,63 @@ class StatsOpsTest(test_util.TensorFlowTestCase): logits_dimension=1) self.assertAllEqual([], node_ids) + def testSparseCalculateBestSplitsWithMinNodeWeightNoSplitOnFeature(self): + """Testing best split calculation with min node weight and no split.""" + node_id_range = [1, 3] # node 1 through 2 will be processed. + stats_summary = np.asarray([ + [ + [[0., 0.], [.0, .0], [0., 0.], [0., 0.]], # node 0; ignored + [[0., 0.], [.15, .36], [.06, .7], [.1, .2]], # node 1 + [[0., 0.], [-.33, .068], [0., 0.], [.3, .04]], # node 2 + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored + ], # feature 0 + [ + [[0., 0.], [0., 0.], [.0, .0], [0., 0.]], # node 0; ignored + [[0., 0.], [.3, .5], [-.05, .6], [.06, .07]], # node 1 + [[.1, .1], [.2, .03], [-.4, .05], [.07, .08]], # node 2 + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored + [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored + ], # feature 1 + ]) # num_features * shape=[max_splits, num_buckets, 2] + # reshape to [max_splits, num_features, num_buckets, 2] + stats_summary = np.moveaxis(stats_summary, 0, 1) + (summary_indices, summary_values, + summary_shape) = self._get_sparse_stats_summary_for_split(stats_summary) + + (node_ids, _, _, _, _, _, _) = self.evaluate( + boosted_trees_ops.sparse_calculate_best_feature_split( + node_id_range, + summary_indices, + summary_values, + summary_shape, + l1=0., + l2=0., + tree_complexity=0., + min_node_weight=1, + logits_dimension=1)) + + # We can't split either of the nodes on the first feature + self.assertAllEqual([1], node_ids) + + # Now check when we can't split on any feature + (node_ids, _, _, _, _, _, _) = self.evaluate( + boosted_trees_ops.sparse_calculate_best_feature_split( + node_id_range, + summary_indices, + summary_values, + summary_shape, + l1=0., + l2=0., + tree_complexity=0., + min_node_weight=10, + logits_dimension=1)) + self.assertAllEqual([], node_ids) + @test_util.run_deprecated_v1 def testMakeStatsSummarySimple(self): """Simple test for MakeStatsSummary.""" diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py index 30b6bc16938..63b26e913d2 100644 --- a/tensorflow/python/ops/boosted_trees_ops.py +++ b/tensorflow/python/ops/boosted_trees_ops.py @@ -40,6 +40,7 @@ from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_s from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as quantile_resource_handle_op from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_aggregate_stats +from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_calculate_best_feature_split as sparse_calculate_best_feature_split from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as is_quantile_resource_initialized diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 16725f81b8c..6d9225cb768 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -532,6 +532,10 @@ tf_module { name: "BoostedTreesSparseAggregateStats" argspec: "args=[\'node_ids\', \'gradients\', \'hessians\', \'feature_indices\', \'feature_values\', \'feature_shape\', \'max_splits\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "BoostedTreesSparseCalculateBestFeatureSplit" + argspec: "args=[\'node_id_range\', \'stats_summary_indices\', \'stats_summary_values\', \'stats_summary_shape\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'split_type\', \'name\'], varargs=None, keywords=None, defaults=[\'inequality\', \'None\'], " + } member_method { name: "BoostedTreesTrainingPredict" argspec: "args=[\'tree_ensemble_handle\', \'cached_tree_ids\', \'cached_node_ids\', \'bucketized_features\', \'logits_dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 16725f81b8c..6d9225cb768 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -532,6 +532,10 @@ tf_module { name: "BoostedTreesSparseAggregateStats" argspec: "args=[\'node_ids\', \'gradients\', \'hessians\', \'feature_indices\', \'feature_values\', \'feature_shape\', \'max_splits\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "BoostedTreesSparseCalculateBestFeatureSplit" + argspec: "args=[\'node_id_range\', \'stats_summary_indices\', \'stats_summary_values\', \'stats_summary_shape\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'split_type\', \'name\'], varargs=None, keywords=None, defaults=[\'inequality\', \'None\'], " + } member_method { name: "BoostedTreesTrainingPredict" argspec: "args=[\'tree_ensemble_handle\', \'cached_tree_ids\', \'cached_node_ids\', \'bucketized_features\', \'logits_dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "