diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt new file mode 100644 index 00000000000..2bbaba26257 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestFeatureSplitV2.pbtxt @@ -0,0 +1,124 @@ +op { + graph_op_name: "BoostedTreesCalculateBestFeatureSplitV2" + visibility: HIDDEN + in_arg { + name: "node_id_range" + description: <; using ConstVectorMap = Eigen::Map; using VectorMap = Eigen::Map; -// V1 Op. Deprecated. BoostedTreesCalculateBestFeatureSplitOp is V2. +constexpr char kInequalitySplit[] = "inequality"; +constexpr char kEqualitySplit[] = "equality"; + +// V1 Op. Deprecated. BoostedTreesCalculateBestFeatureSplitOpV2 is V2. class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel { public: explicit BoostedTreesCalculateBestGainsPerFeatureOp( @@ -227,7 +230,7 @@ REGISTER_KERNEL_BUILDER( Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU), BoostedTreesCalculateBestGainsPerFeatureOp); -// V2 Op. +// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitOpV2. class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel { public: explicit BoostedTreesCalculateBestFeatureSplitOp( @@ -545,11 +548,394 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel { string split_type_; }; -// v2 op that supports multi-class. +// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitOpV2. REGISTER_KERNEL_BUILDER( Name("BoostedTreesCalculateBestFeatureSplit").Device(DEVICE_CPU), BoostedTreesCalculateBestFeatureSplitOp); +// V2 Op. +class BoostedTreesCalculateBestFeatureSplitV2 : public OpKernel { + public: + explicit BoostedTreesCalculateBestFeatureSplitV2( + OpKernelConstruction* const context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_)); + OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_)); + } + + void Compute(OpKernelContext* const context) override { + // node_id_range + const Tensor* node_id_range_t; + OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t)); + const auto node_id_range = node_id_range_t->vec(); + const int32 node_id_first = node_id_range(0); // Inclusive. + const int32 node_id_last = node_id_range(1); // Exclusive. + + // Get stats_summaries_list. + OpInputList stats_summaries_list; + OP_REQUIRES_OK(context, context->input_list("stats_summaries_list", + &stats_summaries_list)); + + // Infer dimensions of a stats_summary. + DCHECK_GT(stats_summaries_list.size(), 0); + const int32 feature_dims = stats_summaries_list[0].dim_size(1); + // The last bucket is for default/missing value. + const int32 num_buckets = stats_summaries_list[0].dim_size(2) - 1; + const int32 logits_dim = logits_dim_; + const int32 hessian_dim = stats_summaries_list[0].dim_size(3) - logits_dim; + DCHECK_GT(hessian_dim, 0); + DCHECK_LE(hessian_dim, logits_dim * logits_dim); + + // Vector of stats_summaries; each element is stats for feature of shape + // [max_splits, feature_dim, num_buckets, logits_dim + hessian_dim]. + std::vector::ConstTensor> stats_summaries; + DCHECK_EQ(stats_summaries_list.size(), num_features_); + stats_summaries.reserve(num_features_); + for (const auto& tensor : stats_summaries_list) { + stats_summaries.emplace_back(tensor.tensor()); + } + + // Split types. 
+ const Tensor* split_types_t; + OP_REQUIRES_OK(context, context->input("split_types", &split_types_t)); + const auto split_types = split_types_t->vec(); + DCHECK_EQ(split_types.size(), num_features_); + // Validate. + for (int i = 0; i < num_features_; ++i) { + if (!(split_types(i) == kInequalitySplit || + split_types(i) == kEqualitySplit)) { + OP_REQUIRES_OK( + context, + errors::Aborted( + "Operation received an exception: Incorrect split type")); + } + } + // Feature ids. + const Tensor* candidate_feature_ids_t; + OP_REQUIRES_OK(context, context->input("candidate_feature_ids", + &candidate_feature_ids_t)); + const auto candidate_feature_ids = candidate_feature_ids_t->vec(); + DCHECK_EQ(candidate_feature_ids.size(), num_features_); + + // L1, L2, tree_complexity, min_node_weight. + const Tensor* l1_t; + OP_REQUIRES_OK(context, context->input("l1", &l1_t)); + const auto l1 = l1_t->scalar()(); + DCHECK_GE(l1, 0); + if (logits_dim_ > 1) { + // Multi-class L1 regularization not supported yet. + DCHECK_EQ(l1, 0); + } + const Tensor* l2_t; + OP_REQUIRES_OK(context, context->input("l2", &l2_t)); + const auto l2 = l2_t->scalar()(); + DCHECK_GE(l2, 0); + const Tensor* tree_complexity_t; + OP_REQUIRES_OK(context, + context->input("tree_complexity", &tree_complexity_t)); + const auto tree_complexity = tree_complexity_t->scalar()(); + const Tensor* min_node_weight_t; + OP_REQUIRES_OK(context, + context->input("min_node_weight", &min_node_weight_t)); + const auto min_node_weight = min_node_weight_t->scalar()(); + + std::vector output_node_ids; + std::vector output_gains; + std::vector output_feature_ids; + std::vector output_feature_dimensions; + std::vector output_thresholds; + std::vector output_left_node_contribs; + std::vector output_right_node_contribs; + std::vector output_split_types; + + // TODO(tanzheny) parallelize the computation. + // Iterate each node and find the best gain per node. + float parent_gain; + for (int32 node_id = node_id_first; node_id < node_id_last; ++node_id) { + float best_gain = std::numeric_limits::lowest(); + int32 best_bucket; + int32 best_f_id; + int32 best_f_dim; + string best_split_type; + Eigen::VectorXf best_contrib_for_left(logits_dim); + Eigen::VectorXf best_contrib_for_right(logits_dim); + + // Sum of gradient and hessian. Compute parent gain using first feature. + ConstMatrixMap stats_mat(&stats_summaries[0](node_id, 0, 0, 0), + num_buckets + 1, // Including default bucket. 
+ logits_dim + hessian_dim); + const Eigen::VectorXf total_grad = + stats_mat.leftCols(logits_dim).colwise().sum(); + const Eigen::VectorXf total_hess = + stats_mat.rightCols(hessian_dim).colwise().sum(); + if (total_hess.norm() < min_node_weight) { + continue; + } + Eigen::VectorXf unused(logits_dim); + CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused, + &parent_gain); + for (int f_idx = 0; f_idx < num_features_; ++f_idx) { + const string split_type = split_types(f_idx); + TTypes::ConstTensor stats_summary = stats_summaries[f_idx]; + float f_best_gain = std::numeric_limits::lowest(); + int32 f_best_bucket; + int32 f_best_f_dim; + string f_best_split_type; + Eigen::VectorXf f_best_contrib_for_left(logits_dim); + Eigen::VectorXf f_best_contrib_for_right(logits_dim); + + if (split_type == kInequalitySplit) { + CalculateBestInequalitySplit( + stats_summary, node_id, feature_dims, logits_dim, hessian_dim, + num_buckets, min_node_weight, l1, l2, &f_best_gain, + &f_best_bucket, &f_best_f_dim, &f_best_split_type, + &f_best_contrib_for_left, &f_best_contrib_for_right); + } else { + CalculateBestEqualitySplit( + stats_summary, total_grad, total_hess, node_id, feature_dims, + logits_dim, hessian_dim, num_buckets, l1, l2, &f_best_gain, + &f_best_bucket, &f_best_f_dim, &f_best_split_type, + &f_best_contrib_for_left, &f_best_contrib_for_right); + } + if (f_best_gain > best_gain) { + best_gain = f_best_gain; + best_f_id = candidate_feature_ids(f_idx); + best_f_dim = f_best_f_dim; + best_split_type = f_best_split_type; + best_bucket = f_best_bucket; + best_contrib_for_left = f_best_contrib_for_left; + best_contrib_for_right = f_best_contrib_for_right; + } + } // For feature id. + if (best_gain == std::numeric_limits::lowest()) { + // Do not add the node if no split is found. + continue; + } + output_node_ids.push_back(node_id); + // Remove the parent gain for the parent node. + output_gains.push_back(best_gain - parent_gain); + output_feature_ids.push_back(best_f_id); + output_feature_dimensions.push_back(best_f_dim); + // Default direction is fixed for dense splits. + // TODO(tanzheny) account for default values. + output_split_types.push_back(best_split_type); + output_thresholds.push_back(best_bucket); + output_left_node_contribs.push_back(best_contrib_for_left); + output_right_node_contribs.push_back(best_contrib_for_right); + } // for node id. 
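+    // Nodes that were skipped above (hessian norm below min_node_weight, or
+    // no valid split found) contribute nothing, so num_nodes below may be
+    // smaller than node_id_last - node_id_first.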
+ const int num_nodes = output_node_ids.size(); + // output_node_ids + Tensor* output_node_ids_t = nullptr; + OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes}, + &output_node_ids_t)); + auto output_node_ids_vec = output_node_ids_t->vec(); + + // output_gains + Tensor* output_gains_t; + OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes}, + &output_gains_t)); + auto output_gains_vec = output_gains_t->vec(); + + // output_feature_ids + Tensor* output_features_ids_t; + OP_REQUIRES_OK(context, context->allocate_output("feature_ids", {num_nodes}, + &output_features_ids_t)); + auto output_features_vec = output_features_ids_t->vec(); + + // output_feature_dimensions + Tensor* output_feature_dimension_t; + OP_REQUIRES_OK(context, + context->allocate_output("feature_dimensions", {num_nodes}, + &output_feature_dimension_t)); + auto output_feature_dimensions_vec = + output_feature_dimension_t->vec(); + + // output_thresholds + Tensor* output_thresholds_t; + OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes}, + &output_thresholds_t)); + auto output_thresholds_vec = output_thresholds_t->vec(); + + // output_left_node_contribs + Tensor* output_left_node_contribs_t; + OP_REQUIRES_OK(context, context->allocate_output( + "left_node_contribs", {num_nodes, logits_dim}, + &output_left_node_contribs_t)); + auto output_left_node_contribs_matrix = + output_left_node_contribs_t->matrix(); + + // output_right_node_contribs + Tensor* output_right_node_contribs_t; + OP_REQUIRES_OK(context, context->allocate_output( + "right_node_contribs", {num_nodes, logits_dim}, + &output_right_node_contribs_t)); + auto output_right_node_contribs_matrix = + output_right_node_contribs_t->matrix(); + + // split type + Tensor* output_split_types_t; + OP_REQUIRES_OK( + context, context->allocate_output("split_with_default_directions", + {num_nodes}, &output_split_types_t)); + auto output_split_types_vec = output_split_types_t->vec(); + + // Sets output tensors from vectors. + for (int i = 0; i < num_nodes; ++i) { + output_node_ids_vec(i) = output_node_ids[i]; + output_features_vec(i) = output_feature_ids[i]; + // Adjust the gains to penalize by tree complexity. + output_gains_vec(i) = output_gains[i] - tree_complexity; + output_feature_dimensions_vec(i) = output_feature_dimensions[i]; + output_thresholds_vec(i) = output_thresholds[i]; + for (int j = 0; j < logits_dim; ++j) { + output_left_node_contribs_matrix(i, j) = + output_left_node_contribs[i][j]; + output_right_node_contribs_matrix(i, j) = + output_right_node_contribs[i][j]; + } + output_split_types_vec(i) = output_split_types[i]; + } + } + + private: + // TODO(crawles): Simplify inequality path just like equality b/138329196 + // Currently this is not simplify-able due to numerical instability in math + // i.e. gain = -g.transpose() * hessian_and_reg.colPivHouseholderQr().solve(g) + // It caused gain to be Inf when g is approaching 0 but not exactly 0 while + // there is no regularization. + // Calculate the best inequality split per node. 
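+  // For each feature dimension, cumulative gradient/hessian sums are built
+  // across buckets and every threshold is evaluated twice: once with the
+  // default (missing-value) bucket sent to the left child and once with it
+  // sent to the right child.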
+ void CalculateBestInequalitySplit( + TTypes::ConstTensor stats_summary, const int32 node_id, + const int32 feature_dims, const int32 logits_dim, const int32 hessian_dim, + const int32 num_buckets, const float min_node_weight, const float l1, + const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim, + string* best_split_type, Eigen::VectorXf* best_contrib_for_left, + Eigen::VectorXf* best_contrib_for_right) { + std::vector cum_grad; + std::vector cum_hess; + // get all cumulative gradients including default bucket. + cum_grad.reserve(num_buckets); + cum_hess.reserve(num_buckets); + + for (int f_dim = 0; f_dim < feature_dims; ++f_dim) { + ConstVectorMap default_stats_vec( + &stats_summary(node_id, f_dim, num_buckets, 0), + logits_dim + hessian_dim); + Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim); + Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim); + cum_grad.clear(); + cum_hess.clear(); + Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim); + Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim); + // sum all the gradients including default bucket. + for (int bucket = 0; bucket <= num_buckets; ++bucket) { + for (int i = 0; i < logits_dim; ++i) { + total_grad[i] += stats_summary(node_id, f_dim, bucket, i); + } + for (int i = 0; i < hessian_dim; ++i) { + // Full hessian. + total_hess[i] += + stats_summary(node_id, f_dim, bucket, logits_dim + i); + } + if (bucket < num_buckets) { + cum_grad.push_back(total_grad); + cum_hess.push_back(total_hess); + } + } + const string kInequalityDefaultLeft = + boosted_trees::SplitTypeWithDefault_Name( + boosted_trees::INEQUALITY_DEFAULT_LEFT); + const string kInequalityDefaultRight = + boosted_trees::SplitTypeWithDefault_Name( + boosted_trees::INEQUALITY_DEFAULT_RIGHT); + + // Iterate from left to right, excluding default bucket. + for (int bucket = 0; bucket < num_buckets; ++bucket) { + // default value goes to left node. + const Eigen::VectorXf total_left_grad = + cum_grad[bucket] + missing_bucket_grad; + const Eigen::VectorXf total_left_hess = + cum_hess[bucket] + missing_bucket_hess; + MaybeUpdateBestSplit( + total_left_grad, total_grad - total_left_grad, total_left_hess, + total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2, + kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim, + best_split_type, best_contrib_for_left, best_contrib_for_right); + // default value goes to right node. + MaybeUpdateBestSplit( + cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket], + total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2, + kInequalityDefaultRight, best_gain, best_bucket, best_f_dim, + best_split_type, best_contrib_for_left, best_contrib_for_right); + } // for bucket + } + } + + // Calculate the best equality split per node. 
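+  // An equality split sends a single bucket to the left child and everything
+  // else (the parent totals minus that bucket) to the right child; the
+  // default direction is always right.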
+ void CalculateBestEqualitySplit( + TTypes::ConstTensor stats_summary, + const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess, + const int32 node_id, const int32 feature_dims, const int32 logits_dim, + const int32 hessian_dim, const int32 num_buckets, const float l1, + const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim, + string* best_split_type, Eigen::VectorXf* best_contrib_for_left, + Eigen::VectorXf* best_contrib_for_right) { + const string kEqualityDefaultRight = + boosted_trees::SplitTypeWithDefault_Name( + boosted_trees::EQUALITY_DEFAULT_RIGHT); + for (int f_dim = 0; f_dim < feature_dims; ++f_dim) { + for (int bucket = 0; bucket < num_buckets; ++bucket) { + ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0), + logits_dim + hessian_dim); + Eigen::VectorXf curr_grad = stats_vec.head(logits_dim); + Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim); + MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess, + total_hess - curr_hess, logits_dim, bucket, f_dim, + l1, l2, kEqualityDefaultRight, best_gain, + best_bucket, best_f_dim, best_split_type, + best_contrib_for_left, best_contrib_for_right); + } + } + } + + void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left, + const Eigen::VectorXf& grad_for_right, + const Eigen::VectorXf& hess_for_left, + const Eigen::VectorXf& hess_for_right, + const int32 logits_dim, const int32 bucket, + const int32 f_dim, const float l1, const float l2, + const string split_type, float* best_gain, + int32* best_bucket, int32* best_f_dim, + string* best_split_type, + Eigen::VectorXf* best_contrib_for_left, + Eigen::VectorXf* best_contrib_for_right) { + // Left child. + Eigen::VectorXf contrib_for_left(logits_dim); + float gain_for_left; + CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2, + &contrib_for_left, &gain_for_left); + Eigen::VectorXf contrib_for_right(logits_dim); + float gain_for_right; + CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2, + &contrib_for_right, &gain_for_right); + if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) { + *best_gain = gain_for_left + gain_for_right; + *best_bucket = bucket; + *best_f_dim = f_dim; + *best_contrib_for_left = contrib_for_left; + *best_contrib_for_right = contrib_for_right; + *best_split_type = split_type; + } + } + int num_features_; + int logits_dim_; +}; + +// v2 op that supports multi-class. +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesCalculateBestFeatureSplitV2").Device(DEVICE_CPU), + BoostedTreesCalculateBestFeatureSplitV2); + // Map from bucket id to vector of statistics. typedef std::map> BucketMap; typedef BucketMap::iterator BucketMapIterator; diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc index 80de75d7fc5..639a753b5dc 100644 --- a/tensorflow/core/ops/boosted_trees_ops.cc +++ b/tensorflow/core/ops/boosted_trees_ops.cc @@ -141,6 +141,74 @@ REGISTER_OP("BoostedTreesCalculateBestFeatureSplit") return Status::OK(); }); +REGISTER_OP("BoostedTreesCalculateBestFeatureSplitV2") + .Input("node_id_range: int32") + .Input("stats_summaries_list: num_features * float32") + .Input("split_types: string") + .Input("candidate_feature_ids: int32") + .Input("l1: float") + .Input("l2: float") + .Input("tree_complexity: float") + .Input("min_node_weight: float") + .Attr("num_features: int >= 1") // not passed but populated automatically. 
+ .Attr("logits_dimension: int >= 1") + .Output("node_ids: int32") + .Output("gains: float32") + .Output("feature_ids: int32") + .Output("feature_dimensions: int32") + .Output("thresholds: int32") + .Output("left_node_contribs: float32") + .Output("right_node_contribs: float32") + .Output("split_with_default_directions: string") + .SetShapeFn([](shape_inference::InferenceContext* c) { + // Attributes. + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + int logits_dimension; + TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension)); + // Inputs. + shape_inference::ShapeHandle unused_shape; + // node id range is rank 1 with 2 values. + shape_inference::ShapeHandle node_id_range_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape)); + TF_RETURN_IF_ERROR( + c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape)); + // Stats summary validation. + shape_inference::ShapeHandle summary_shape_base; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &summary_shape_base)); + // All stats summary entries are of the same shape. + for (int i = 1; i < num_features; ++i) { + shape_inference::ShapeHandle summary_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 4, &summary_shape)); + TF_RETURN_IF_ERROR( + c->Merge(summary_shape_base, summary_shape, &unused_shape)); + } + // Validate rank 1 split_types. + TF_RETURN_IF_ERROR( + c->WithRank(c->input(1 + num_features), 1, &unused_shape)); + // Validate rank 1 feature_ids. + TF_RETURN_IF_ERROR( + c->WithRank(c->input(2 + num_features), 1, &unused_shape)); + // Validate rank 0: l1, l2, tree_complexity, min_node_weight. + for (int i = 0; i < 4; ++i) { + TF_RETURN_IF_ERROR( + c->WithRank(c->input(3 + num_features + i), 0, &unused_shape)); + } + // Output shapes. + ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()}); + c->set_output(0, rank_1_output_shape); + c->set_output(1, rank_1_output_shape); + c->set_output(2, rank_1_output_shape); + c->set_output(3, rank_1_output_shape); + c->set_output(4, rank_1_output_shape); + ShapeHandle contribs_output_shape = + c->MakeShape({c->UnknownDim(), logits_dimension}); + c->set_output(5, contribs_output_shape); + c->set_output(6, contribs_output_shape); + c->set_output(7, rank_1_output_shape); + return Status::OK(); + }); + REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit") .Input("node_id_range: int32") .Input("stats_summary_indices: int32") diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py index 402c6f041e0..c5f58f1f6b2 100644 --- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py +++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py @@ -46,6 +46,24 @@ class StatsOpsTest(test_util.TensorFlowTestCase): axis=2) return stats_summary + def add_f_dim_and_append_zeros(self, stats_summaries): + """Transform a list of stats summaries, adding a feature dimension. + + The input shape is a list of arrays of shape [max_splits, num_buckets, + logits+hess dim]. This transformation returns a list of arrays of shape + [max_splits, 1, num_buckets + 1, logits+hess dim]. + + Args: + stats_summaries: a list of numpy arrays. + + Returns: + A list of numpy arrays. 
+ """ + return [ + self._append_zeros_for_default_bucket(np.expand_dims(feature, axis=1)) + for feature in stats_summaries + ] + def _get_stats_summary_for_split(self): return [ [ @@ -160,7 +178,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[0.833333], [0.8]], right_node_contribs) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateBestGainsWithoutRegularization(self): + def testCalculateBestGainsWithoutRegularization_v1_op(self): """Testing Gain calculation without any regularization.""" with self.cached_session() as sess: max_splits = 7 @@ -189,19 +207,40 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[[-.592593], [-.75]], [[-.076923], [.568966]]], self.evaluate(right_node_contribs_list)) - def testCalculateBestMultiDimFeatureSplitsWithoutRegularization(self): + def testCalculateBestFeaturesInvalidSplitType_v2_op(self): """Testing best split calculation without any regularization.""" + candidate_feature_ids = [9, 12] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) - stats_summary = self._append_zeros_for_default_bucket(stats_summary) + stats_summaries = self._get_stats_summary_for_split() + stats_summaries = self.add_f_dim_and_append_zeros(stats_summaries) - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( + with self.assertRaisesRegexp(Exception, 'Incorrect split type'): + self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, + stats_summaries, + split_types=['INVALID'] * len(candidate_feature_ids), + candidate_feature_ids=candidate_feature_ids, + l1=0.0, + l2=0.0, + tree_complexity=0.0, + min_node_weight=0, + logits_dimension=1)) + + def testCalculateBestFeaturesWithoutRegularization_v2_op(self): + """Testing best split calculation without any regularization.""" + candidate_feature_ids = [9, 12] + node_id_range = [1, 3] # node 1 through 2 will be processed. + stats_summaries = self._get_stats_summary_for_split() + stats_summaries = self.add_f_dim_and_append_zeros(stats_summaries) + + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( node_id_range, - stats_summary, + stats_summaries, + split_types=['inequality'] * len(candidate_feature_ids), + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, @@ -209,10 +248,47 @@ class StatsOpsTest(test_util.TensorFlowTestCase): logits_dimension=1)) # Get same result as v1 op (CalculateBestGainsPerFeature), and find the - # feature dimension that has the best gain. + # feature_id and dimension that has the best gain per node. self.assertAllEqual([1, 2], node_ids) self.assertAllClose([0.02823, 0.41184], gains) self.assertAllEqual([1, 1], thresholds) + self.assertAllEqual([12, 9], feature_ids) + f_dim = 0 # Both features only have one dimension. + self.assertAllEqual([f_dim] * 2, feature_dimensions) + # The left node contrib will be later added to the previous node value to + # make the left node value, and the same for right node contrib. 
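+    # Node 1's best split comes from feature 12 and node 2's from feature 9,
+    # so each row below corresponds to that winning feature.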
+ self.assertAllClose([[-.6], [.568966]], left_node_contribs) + self.assertAllClose([[-.076923], [-.75]], right_node_contribs) + self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) + + def testCalculateBestMultiDimFeatureSplitsWithoutRegularization_v2_op(self): + """Testing best split without any regularization for a multi-dim feature.""" + candidate_feature_ids = [4] + node_id_range = [1, 3] # node 1 through 2 will be processed. + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. + stats_summary = np.moveaxis(stats_summaries, 0, 1) + stats_summary = self._append_zeros_for_default_bucket(stats_summary) + + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, + l1=0.0, + l2=0.0, + tree_complexity=0.0, + min_node_weight=0, + logits_dimension=1)) + + # Get same result as v1 op (CalculateBestGainsPerFeature), and find the + # feature_id and dimension that has the best gain per node. + self.assertAllEqual([1, 2], node_ids) + self.assertAllClose([0.02823, 0.41184], gains) + self.assertAllEqual([1, 1], thresholds) + self.assertAllEqual([4, 4], feature_ids) self.assertAllEqual([1, 0], feature_dimensions) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. @@ -220,18 +296,22 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[-.076923], [-.75]], right_node_contribs) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateBestMultiDimFeatureSplitWMissingValuesWORegularization(self): + def testCalculateBestMultiDimFeatureSplitWMissingValuesWORegularization_v2_op( + self): """Testing best split calculation without any regularization.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. + stats_summary = np.moveaxis(stats_summaries, 0, 1) - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, @@ -242,39 +322,44 @@ class StatsOpsTest(test_util.TensorFlowTestCase): # feature dimension that has the best gain. 
self.assertAllEqual([1, 2], node_ids) self.assertAllClose([0.116495, 0.60429], gains) - self.assertAllEqual([1, 1], thresholds) + self.assertAllEqual([4, 4], feature_ids) self.assertAllEqual([1, 1], feature_dimensions) + self.assertAllEqual([1, 1], thresholds) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. self.assertAllClose([[-0.631579], [-0.770833]], left_node_contribs) self.assertAllClose([[0.833333], [0.8]], right_node_contribs) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateBestMultiDimFeatureEqualitySplitsWithoutRegularization(self): + def testCalculateBestMultiDimFeatureEqualitySplitsWithoutRegularization_v2_op( + self): """Testing best split calculation without any regularization.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. + stats_summary = np.moveaxis(stats_summaries, 0, 1) - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['equality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, min_node_weight=0, - logits_dimension=1, - split_type='equality')) + logits_dimension=1)) self.assertAllEqual([1, 2], node_ids) # 0.116495 = (-0.05)^2/0.06 + 0.36^2/0.57 - 0.31^2/0.63 # 0.60429 = (-0.4)^2/0.5 + 0.37^2/0.48 - 0.03^2/0.98 self.assertAllClose([0.116495, 0.60429], gains) - self.assertAllEqual([2, 2], thresholds) + self.assertAllEqual([4, 4], feature_ids) self.assertAllEqual([1, 1], feature_dimensions) + self.assertAllEqual([2, 2], thresholds) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. # left contrib 0.83 = 0.05/0.06, 0.8 = 0.4/0.5 @@ -283,7 +368,48 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[-0.631579], [-0.770833]], right_node_contribs) self.assertAllEqual([_EQUALITY_DEFAULT_RIGHT] * 2, split_types) - def testCalculateBestGainsWithL2(self): + def testCalculateBestMultiDimFeatureMixedSplitTypeWithoutRegularization_v2_op( + self): + """Testing best split calculation without any regularization.""" + candidate_feature_ids = [9, 12] + node_id_range = [1, 3] # node 1 through 2 will be processed. + stats_summaries = self._get_stats_summary_for_split() + # Add in feature dimension. 
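+    # No zeros are appended here, so the op treats the trailing bucket of each
+    # feature as its default/missing-value bucket.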
+ stats_summaries = [ + np.expand_dims(feature, axis=1) for feature in stats_summaries + ] + + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, + stats_summaries, + split_types=['inequality', 'equality'], + candidate_feature_ids=candidate_feature_ids, + l1=0.0, + l2=0.0, + tree_complexity=0.0, + min_node_weight=0, + logits_dimension=1)) + + self.assertAllEqual([1, 2], node_ids) + # 0.116495 = (-0.05)^2/0.06 + 0.36^2/0.57 - 0.31^2/0.63 + # 0.60429 = (-0.4)^2/0.5 + 0.37^2/0.48 - 0.03^2/0.98 + self.assertAllClose([0.116495, 0.60429], gains) + self.assertAllEqual([12, 12], feature_ids) + f_dim = 0 # Both features only have one dimension. + self.assertAllEqual([f_dim, f_dim], feature_dimensions) + self.assertAllEqual([2, 2], thresholds) + # Same result as equality only test, as feature_1 is chose for both nodes. + # left contrib 0.83 = 0.05/0.06, 0.8 = 0.4/0.5 + self.assertAllClose([[0.833333], [.8]], left_node_contribs) + # right contrib -0.6315 = -0.36/0.57, -0.7708 = -0.37/0.48 + self.assertAllClose([[-0.631579], [-0.770833]], right_node_contribs) + # Feature 1 is inequality. + self.assertAllEqual([_EQUALITY_DEFAULT_RIGHT, _EQUALITY_DEFAULT_RIGHT], + split_types) + + def testCalculateBestGainsWithL2_v1_op(self): """Testing Gain calculation with L2.""" with self.cached_session() as sess: max_splits = 7 @@ -312,19 +438,22 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]], self.evaluate(right_node_contribs_list)) - def testCalculateMultiDimBestFeatureSplitsWithL2(self): + def testCalculateMultiDimBestFeatureSplitsWithL2_v2_op(self): """Testing best split calculation with L2.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. + stats_summary = np.moveaxis(stats_summaries, 0, 1) stats_summary = self._append_zeros_for_default_bucket(stats_summary) - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.1, tree_complexity=0.0, @@ -334,27 +463,31 @@ class StatsOpsTest(test_util.TensorFlowTestCase): # Get same result as v1 op (CalculateBestGainsPerFeature), and find the # feature dimension that has the best gain. self.assertAllEqual([1, 2], node_ids) + self.assertAllEqual([4, 4], feature_ids) + self.assertAllEqual([1, 0], feature_dimensions) self.assertAllClose([0.01879096, 0.33931375], gains) self.assertAllEqual([1, 1], thresholds) - self.assertAllEqual([1, 0], feature_dimensions) # # The left node contrib will be later added to the previous node value to # # make the left node value, and the same for right node contrib. 
self.assertAllClose([[-.5], [.485294]], left_node_contribs) self.assertAllClose([[-.043478], [-.6]], right_node_contribs) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateMultiDimBestFeatureSplitsWithMissingValuesL2(self): + def testCalculateMultiDimBestFeatureSplitsWithMissingValuesL2_v2_op(self): """Testing best split calculation with L2.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. + stats_summary = np.moveaxis(stats_summaries, 0, 1) - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.1, tree_complexity=0.0, @@ -364,40 +497,44 @@ class StatsOpsTest(test_util.TensorFlowTestCase): # Get same result as v1 op (CalculateBestGainsPerFeature), and find the # feature dimension that has the best gain. self.assertAllEqual([1, 2], node_ids) + self.assertAllEqual([4, 4], feature_ids) + self.assertAllEqual([1, 1], feature_dimensions) self.assertAllClose([0.077414, 0.501868], gains) self.assertAllEqual([1, 1], thresholds) - self.assertAllEqual([1, 1], feature_dimensions) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. self.assertAllClose([[-0.537313], [-0.637931]], left_node_contribs) self.assertAllClose([[0.3125], [0.666667]], right_node_contribs) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateMultiDimBestFeatureEqualitySplitsWithL2(self): + def testCalculateMultiDimBestFeatureEqualitySplitsWithL2_v2_op(self): """Testing best split calculation with L2.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. 
+ stats_summary = np.moveaxis(stats_summaries, 0, 1) - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['equality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.1, tree_complexity=0.0, min_node_weight=0, - logits_dimension=1, - split_type='equality')) + logits_dimension=1)) self.assertAllEqual([1, 2], node_ids) + self.assertAllEqual([4, 4], feature_ids) + self.assertAllEqual([1, 1], feature_dimensions) # 0.077414 = 0.05^2/0.16 + 0.36^2/0.67 - 0.31^2/0.73 # 0.501868 = 0.4^2/0.6 + 0.37^2/0.58 - 0.03^2/1.08 self.assertAllClose([0.077414, 0.501868], gains) self.assertAllEqual([2, 2], thresholds) - self.assertAllEqual([1, 1], feature_dimensions) # # The left node contrib will be later added to the previous node value to # # make the left node value, and the same for right node contrib. # left contrib 0.3125 = 0.05/0.16, 0.6667 = 0.4/0.6 @@ -434,7 +571,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT, _INEQUALITY_DEFAULT_LEFT], split_types) - def testCalculateBestGainsWithL1(self): + def testCalculateBestGainsWithL1_v1_op(self): """Testing Gain calculation with L1.""" with self.cached_session() as sess: max_splits = 7 @@ -466,22 +603,24 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[0.0, 0.191207], [0.01, 0.191207]], self.evaluate(gains_list)) - def testCalculateBestMultiDimFeatureSplitsWithL1(self): + def testCalculateBestMultiDimFeatureSplitsWithL1_v2_op(self): """Testing best split calculation with L1.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. + stats_summary = np.moveaxis(stats_summaries, 0, 1) stats_summary = self._append_zeros_for_default_bucket(stats_summary) - l1 = 0.1 - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, - l1=l1, - l2=0., + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, + l1=0.1, + l2=0.0, tree_complexity=0.0, min_node_weight=0, logits_dimension=1)) @@ -489,29 +628,32 @@ class StatsOpsTest(test_util.TensorFlowTestCase): # Get same result as v1 op (CalculateBestGainsPerFeature), and find the # feature dimension that has the best gain. self.assertAllEqual([1, 2], node_ids) + self.assertAllEqual([4, 4], feature_ids) + self.assertAllEqual([1, 1], feature_dimensions) # Gain should also include an adjustment of the gradient by l1. 
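+    # (each summed gradient is shrunk toward zero by l1 before the squared
+    # gradient over hessian terms are computed).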
self.assertAllClose([0.01, 0.191207], gains) self.assertAllEqual([1, 1], thresholds) self.assertAllClose([[-0.4], [-0.5]], left_node_contribs) self.assertAllClose([[0.], [0.396552]], right_node_contribs) - self.assertAllEqual([1, 1], feature_dimensions) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateBestMultiDimFeatureSplitsWithMissingValuesL1(self): + def testCalculateBestMultiDimFeatureSplitsWithMissingValuesL1_v2_op(self): """Testing best split calculation with L1.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. + stats_summary = np.moveaxis(stats_summaries, 0, 1) - l1 = 0.1 - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, - l1=l1, - l2=0., + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, + l1=0.1, + l2=0.0, tree_complexity=0.0, min_node_weight=0, logits_dimension=1)) @@ -519,6 +661,8 @@ class StatsOpsTest(test_util.TensorFlowTestCase): # Get same result as v1 op (CalculateBestGainsPerFeature), and find the # feature dimension that has the best gain. self.assertAllEqual([1, 2], node_ids) + self.assertAllEqual([4, 4], feature_ids) + self.assertAllEqual([1, 1], feature_dimensions) # Gain should also include an adjustment of the gradient by l1. # (0.36-0.1)^2/0.57 + 0 - (0.31-0.1)^2/0.63 = 0.048597 # (0.37-0.1)^2/0.48 + (-0.4+0.1)^2/0.5 = 0.331875 @@ -529,35 +673,37 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[-0.45614], [-0.5625]], left_node_contribs) # -(-0.4+0.1)/0.5 = 0.6 self.assertAllClose([[0.], [0.6]], right_node_contribs) - self.assertAllEqual([1, 1], feature_dimensions) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateBestMultiDimFeatureEqualitySplitsWithL1(self): + def testCalculateBestMultiDimFeatureEqualitySplitsWithL1_v2_op(self): """Testing best split calculation with L1.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. 
+ stats_summary = np.moveaxis(stats_summaries, 0, 1) + stats_summary = self._append_zeros_for_default_bucket(stats_summary) - l1 = 0.1 - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, - l1=l1, - l2=0., + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['equality'], + candidate_feature_ids=candidate_feature_ids, + l1=0.1, + l2=0.0, tree_complexity=0.0, min_node_weight=0, - logits_dimension=1, - split_type='equality')) + logits_dimension=1)) self.assertAllEqual([1, 2], node_ids) # 0.048597 = 0 + 0.26^2/0.57 - 0.21^2/0.63 # 0.501868 = 0.3^2/0.5 + 0.27^2/0.48 - 0 self.assertAllClose([0.048597, 0.331875], gains) - self.assertAllEqual([2, 2], thresholds) + self.assertAllEqual([4, 4], feature_ids) self.assertAllEqual([1, 1], feature_dimensions) + self.assertAllEqual([2, 2], thresholds) # # The left node contrib will be later added to the previous node value to # # make the left node value, and the same for right node contrib. # left contrib 0 (-0.05>-0.1), 0.6 = 0.3/0.5 @@ -593,7 +739,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[0.0], [0.6]], right_node_contribs) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateBestGainsWithTreeComplexity(self): + def testCalculateBestGainsWithTreeComplexity_v1_op(self): """Testing best gain calculation with tree complexity.""" with self.cached_session() as sess: max_splits = 7 @@ -626,24 +772,25 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]], self.evaluate(right_node_contribs_list)) - def testCalculateBestMultiDimFeatureSplitsWithTreeComplexity(self): + def testCalculateBestMultiDimFeatureSplitsWithTreeComplexity_v2_op(self): """Testing best split calculation with tree complexity.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. + stats_summary = np.moveaxis(stats_summaries, 0, 1) stats_summary = self._append_zeros_for_default_bucket(stats_summary) - l2 = 0.1 - tree_complexity = 3. - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, - l1=0., - l2=l2, - tree_complexity=tree_complexity, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, + l1=0.0, + l2=0.1, + tree_complexity=3, min_node_weight=0, logits_dimension=1)) @@ -652,29 +799,32 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([1, 2], node_ids) # Gain should also include an adjustment of the gradient by l1. 
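+    # With tree_complexity=3 these are the L2-regularized gains from the test
+    # above minus 3: 0.01879096 - 3 and 0.33931375 - 3.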
self.assertAllClose([-2.98120904, -2.66068625], gains) + self.assertAllEqual([4, 4], feature_ids) + self.assertAllEqual([1, 0], feature_dimensions) self.assertAllEqual([1, 1], thresholds) self.assertAllClose([[-0.5], [0.485294]], left_node_contribs) self.assertAllClose([[-0.043478], [-.6]], right_node_contribs) - self.assertAllEqual([1, 0], feature_dimensions) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateBestMultiDimFeatureSplitsWMissingValsTreeComplexity(self): + def testCalculateBestMultiDimFeatureSplitsWMissingValsTreeComplexity_v2_op( + self): """Testing best split calculation with tree complexity.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. + stats_summary = np.moveaxis(stats_summaries, 0, 1) - l2 = 0.1 - tree_complexity = 3. - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, - l1=0., - l2=l2, - tree_complexity=tree_complexity, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, + l1=0.0, + l2=0.1, + tree_complexity=3, min_node_weight=0, logits_dimension=1)) @@ -683,38 +833,41 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([1, 2], node_ids) # Gain should also include an adjustment of the gradient by l1. self.assertAllClose([-2.922586, -2.498132], gains) + self.assertAllEqual([4, 4], feature_ids) + self.assertAllEqual([1, 1], feature_dimensions) self.assertAllEqual([1, 1], thresholds) self.assertAllClose([[-0.537313], [-0.637931]], left_node_contribs) self.assertAllClose([[0.3125], [0.666667]], right_node_contribs) - self.assertAllEqual([1, 1], feature_dimensions) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateBestMultiDimFeatureEqualitySplitsWithTreeComplexity(self): + def testCalculateBestMultiDimFeatureEqualitySplitsWithTreeComplexity_v2_op( + self): """Testing best split calculation with tree complexity.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. - stats_summary = np.asarray(self._get_stats_summary_for_split()) - # reshape to [max_splits, feature_dim, num_buckets, 2] - stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summaries = self._get_stats_summary_for_split() + # Convert from list of arrays to a single array and reshape to [max_splits, + # feature_dim, num_buckets, 2]. + stats_summary = np.moveaxis(stats_summaries, 0, 1) - l2 = 0.1 - tree_complexity = 3. 
- (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, - l1=0., - l2=l2, - tree_complexity=tree_complexity, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['equality'], + candidate_feature_ids=candidate_feature_ids, + l1=0.0, + l2=0.1, + tree_complexity=3, min_node_weight=0, - logits_dimension=1, - split_type='equality')) + logits_dimension=1)) self.assertAllEqual([1, 2], node_ids) # -2.922586 = 0.05^2/0.16 + 0.36^2/0.67 - 0.31^2/0.73 - 3 # -2.498132 = 0.4^2/0.6 + 0.37^2/0.58 - 0.03^2/1.08 - 3 self.assertAllClose([-2.922586, -2.498132], gains) self.assertAllEqual([2, 2], thresholds) + self.assertAllEqual([4, 4], feature_ids) self.assertAllEqual([1, 1], feature_dimensions) # # The left node contrib will be later added to the previous node value to # # make the left node value, and the same for right node contrib. @@ -751,7 +904,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[0.3125], [0.666667]], right_node_contribs) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateBestGainsWithMinNodeWeight(self): + def testCalculateBestGainsWithMinNodeWeight_v1_op(self): """Testing Gain calculation with min node weight.""" with self.cached_session() as sess: max_splits = 7 @@ -798,8 +951,9 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([[[-0.75]], [[-0.014925]]], self.evaluate(right_node_contribs_list)) - def testCalculateMultiDimBestSplitsWithMinNodeWeight(self): + def testCalculateMultiDimBestSplitsWithMinNodeWeight_v2_op(self): """Testing best split calculation with min node weight.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. stats_summary = np.asarray([ [ @@ -810,7 +964,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase): [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored - ], # feature 0 + ], # f_dim 0 [ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored [[0., 0.], [.3, .5], [-.05, .6], [.06, .07]], # node 1 @@ -819,34 +973,37 @@ class StatsOpsTest(test_util.TensorFlowTestCase): [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored - ], # feature 1 + ], # f_dim 1 ]) # feature_dim * shape=[max_splits, num_buckets, 2] - # reshape to [max_splits, feature_dim, num_buckets, 2] + # Reshape to [max_splits, feature_dim, num_buckets, 2]. 
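+    # The array above is indexed [feature_dim, max_splits, ...], so the first
+    # two axes are swapped before the default bucket is appended.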
stats_summary = np.moveaxis(stats_summary, 0, 1) stats_summary = self._append_zeros_for_default_bucket(stats_summary) - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, - l1=0., - l2=0., - tree_complexity=0., + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, + l1=0.0, + l2=0.0, + tree_complexity=0.0, min_node_weight=1, logits_dimension=1)) self.assertAllEqual([1, 2], node_ids) # Gain should also include an adjustment of the gradient by l1. self.assertAllClose([0.098013, 0.931596], gains) + self.assertAllEqual([4, 4], feature_ids) + self.assertAllEqual([1, 1], feature_dimensions) self.assertAllEqual([1, 1], thresholds) self.assertAllClose([[-.6], [-0.315789]], left_node_contribs) self.assertAllClose([[-0.014925], [2.53846]], right_node_contribs) - self.assertAllEqual([1, 1], feature_dimensions) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) - def testCalculateMultiDimBestSplitsWithMissingValuesMinNodeWeight(self): + def testCalculateMultiDimBestSplitsWithMissingValuesMinNodeWeight_v2_op(self): """Testing best split calculation with min node weight.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. stats_summary = np.asarray([ [ @@ -857,7 +1014,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase): [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored - ], # feature 0 + ], # f_dim 0 [ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored [[0., 0.], [.3, .5], [-.05, .6], [.06, .07]], # node 1 @@ -866,29 +1023,31 @@ class StatsOpsTest(test_util.TensorFlowTestCase): [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored - ], # feature 1 + ], # f_dim 1 ]) # feature_dim * shape=[max_splits, num_buckets, 2] - # reshape to [max_splits, feature_dim, num_buckets, 2] + # Reshape to [max_splits, feature_dim, num_buckets, 2]. stats_summary = np.moveaxis(stats_summary, 0, 1) - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, - l1=0., - l2=0., - tree_complexity=0., + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, + l1=0.0, + l2=0.0, + tree_complexity=0.0, min_node_weight=1, logits_dimension=1)) self.assertAllEqual([1, 2], node_ids) # Gain should also include an adjustment of the gradient by l1. 
self.assertAllClose([0.149398, 3.332075], gains) + self.assertAllEqual([4, 4], feature_ids) + self.assertAllEqual([1, 1], feature_dimensions) self.assertAllEqual([1, 1], thresholds) self.assertAllClose([[-0.631579], [-0.359223]], left_node_contribs) self.assertAllClose([[0.083333], [7.999989]], right_node_contribs) - self.assertAllEqual([1, 1], feature_dimensions) self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types) def testSparseCalculateBestSplitsWithMinNodeWeight(self): @@ -942,7 +1101,8 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self.assertAllEqual([_INEQUALITY_DEFAULT_RIGHT, _INEQUALITY_DEFAULT_LEFT], split_types) - def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeturePossible(self): + def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeaturePossible_v1_op( + self): """Testing Gain calculation without any regularization.""" with self.cached_session() as sess: max_splits = 7 @@ -995,8 +1155,10 @@ class StatsOpsTest(test_util.TensorFlowTestCase): max_splits=max_splits) self.assertAllEqual([[], []], self.evaluate(node_ids_list)) - def testCalculateBestMultiDimFeatureSplitsWithNoSplitOnFeaturePossible(self): + def testCalculateBestMultiDimFeatureSplitsWithNoSplitOnFeaturePossible_v2_op( + self): """Testing best split calculation with min node weight and no split.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. stats_summary = np.asarray([ [ @@ -1007,7 +1169,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase): [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored - ], # feature 0 + ], # f_dim 0 [ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored [[0., 0.], [.3, .5], [-.05, .06], [.06, .7]], # node 1 @@ -1016,29 +1178,32 @@ class StatsOpsTest(test_util.TensorFlowTestCase): [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored - ], # feature 1 + ], # f_dim 1 ]) # feature_dim * shape=[max_splits, num_buckets, 2] - # reshape to [max_splits, feature_dim, num_buckets, 2] + # Reshape to [max_splits, feature_dim, num_buckets, 2]. stats_summary = np.moveaxis(stats_summary, 0, 1) + stats_summary = self._append_zeros_for_default_bucket(stats_summary) - (node_ids, _, _, _, _, _, - _) = boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, _, _, _, _, _, _, + _) = boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, min_node_weight=1, logits_dimension=1) - # We can't split either of the nodes on the first feature + # We can't split either of the nodes on the first feature. self.assertAllEqual([1], node_ids) - # Now check when we can't split on any feature - (node_ids, _, _, _, _, _, - _) = boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + # Now check when we can't split on any feature. 
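+    # The op should then return empty outputs (node_ids of length 0).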
+ (node_ids, _, _, _, _, _, _, + _) = boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, @@ -1046,8 +1211,10 @@ class StatsOpsTest(test_util.TensorFlowTestCase): logits_dimension=1) self.assertAllEqual([], node_ids) - def testCalculateBestMultiDimFeatureEqualitySplitsWithNoSplitPossible(self): + def testCalculateBestMultiDimFeatureEqualitySplitsWithNoSplitPossible_v2_op( + self): """Testing best split calculation with min node weight and no split.""" + candidate_feature_ids = [4] node_id_range = [1, 3] # node 1 through 2 will be processed. stats_summary = np.asarray([ [ @@ -1058,7 +1225,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase): [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored - ], # feature 0 + ], # f_dim 0 [ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored [[0., 0.], [.3, .5], [-.05, .06], [.06, .7]], # node 1 @@ -1067,30 +1234,31 @@ class StatsOpsTest(test_util.TensorFlowTestCase): [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored - ], # feature 1 + ], # f_dim 1 ]) # feature_dim * shape=[max_splits, num_buckets, 2] - # reshape to [max_splits, feature_dim, num_buckets, 2] + # Reshape to [max_splits, feature_dim, num_buckets, 2]. stats_summary = np.moveaxis(stats_summary, 0, 1) - (node_ids, _, _, _, _, _, - _) = boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, _, _, _, _, _, _, + _) = boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['equality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, min_node_weight=1, - logits_dimension=1, - split_type='equality') + logits_dimension=1) # We can't split either of the nodes on the first feature self.assertAllEqual([1], node_ids) # Now check when we can't split on any feature - (node_ids, _, _, _, _, _, - _) = boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, _, _, _, _, _, _, + _) = boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['equality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, @@ -1502,8 +1670,8 @@ class StatsOpsTest(test_util.TensorFlowTestCase): self._verify_precision(length=50000000) -class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): - """Tests multi-class/multi-regression for best splits.""" +class BestMultiDimFeatureSplitMultiClassV2Op(StatsOpsTest): + """Tests multi-class/multi-regression for best splits using V2 op.""" logits_dim = 2 @@ -1566,6 +1734,7 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): def testCalculateBestFeatureSplitsSingleClassVsMultiClass(self): """Testing same results using same grads/hess with both single and multi.""" + candidate_feature_ids = [14] node_id_range = [1, 3] # node 1 through 2 will be processed. 
# Build same stats summary in single class and multi-class form (using @@ -1589,23 +1758,25 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): # [max_splits, feature_dim, num_buckets, 4] diag_stats_summary = self._add_feature_dim(diag_stats_summary) - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, min_node_weight=0, logits_dimension=1)) - (diag_node_ids, diag_gains, diag_feature_dimensions, diag_thresholds, - diag_left_node_contribs, diag_right_node_contribs, + (diag_node_ids, diag_gains, diag_feature_ids, diag_feature_dimensions, + diag_thresholds, diag_left_node_contribs, diag_right_node_contribs, diag_split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - diag_stats_summary, + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [diag_stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, @@ -1614,8 +1785,9 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): self.assertAllEqual(node_ids, diag_node_ids) self.assertAllClose(gains, diag_gains) - self.assertAllEqual(thresholds, diag_thresholds) + self.assertAllEqual(feature_ids, diag_feature_ids) self.assertAllEqual(feature_dimensions, diag_feature_dimensions) + self.assertAllEqual(thresholds, diag_thresholds) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. zeros = np.zeros_like(left_node_contribs) @@ -1629,6 +1801,7 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): def testCalculateBestFeatureSplitsDiagonalVsFull(self): """Test results are same using diagonal hessian and full hessian.""" + candidate_feature_ids = [14] node_id_range = [1, 3] # node 1 through 2 will be processed. # Build same stats summary in diagonal and full hessian form, respectively. 
@@ -1651,24 +1824,26 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): ] # [max_splits, feature_dim, num_buckets, logits_dim + logits_dim**2] full_stats_summary = self._add_feature_dim(full_stats_summary) - (diag_node_ids, diag_gains, diag_feature_dimensions, diag_thresholds, - diag_left_node_contribs, diag_right_node_contribs, + (diag_node_ids, diag_gains, diag_feature_ids, diag_feature_dimensions, + diag_thresholds, diag_left_node_contribs, diag_right_node_contribs, diag_split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - diag_stats_summary, + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [diag_stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, min_node_weight=0, logits_dimension=self.logits_dim)) - (full_node_ids, full_gains, full_feature_dimensions, full_thresholds, - full_left_node_contribs, full_right_node_contribs, + (full_node_ids, full_gains, full_feature_ids, full_feature_dimensions, + full_thresholds, full_left_node_contribs, full_right_node_contribs, full_split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - full_stats_summary, + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [full_stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, @@ -1677,8 +1852,9 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): self.assertAllEqual(diag_node_ids, full_node_ids) self.assertAllClose(diag_gains, full_gains) - self.assertAllEqual(diag_thresholds, full_thresholds) + self.assertAllEqual(diag_feature_ids, full_feature_ids) self.assertAllEqual(diag_feature_dimensions, full_feature_dimensions) + self.assertAllEqual(diag_thresholds, full_thresholds) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. self.assertAllClose(diag_left_node_contribs, full_left_node_contribs) @@ -1687,16 +1863,18 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): def testCalculateBestFeatureSplitsWithoutRegularization(self): """Testing best split calculation without any regularization.""" + candidate_feature_ids = [14] node_id_range = [1, 3] # node 1 through 2 will be processed. 
# [max_splits, feature_dim, num_buckets, 2*logits_dim] stats_summary = self._get_stats_summary_for_split_diagonal_hessian() stats_summary = self._append_zeros_for_default_bucket(stats_summary) - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, @@ -1706,6 +1884,7 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): self.assertAllEqual([1, 2], node_ids) self.assertAllClose([0.912981, 1.446218], gains) self.assertAllEqual([2, 1], thresholds) + self.assertAllEqual([14, 14], feature_ids) self.assertAllEqual([0, 1], feature_dimensions) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. @@ -1717,15 +1896,17 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): def testCalculateBestFeatureSplitsWMissingValuesWoRegularization(self): """Testing best split calculation without any regularization.""" + candidate_feature_ids = [14] node_id_range = [1, 3] # node 1 through 2 will be processed. # [max_splits, feature_dim, num_buckets, 2*logits_dim] stats_summary = self._get_stats_summary_for_split_diagonal_hessian() - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, @@ -1735,6 +1916,7 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): self.assertAllEqual([1, 2], node_ids) self.assertAllClose([0.912981, 2.79444], gains) self.assertAllEqual([0, 1], thresholds) + self.assertAllEqual([14, 14], feature_ids) self.assertAllEqual([0, 1], feature_dimensions) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. @@ -1746,17 +1928,19 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): def testCalculateBestFeatureSplitsWithL2(self): """Testing best split calculation inith L2 regularization.""" + candidate_feature_ids = [14] node_id_range = [1, 3] # node 1 through 2 will be processed. 
# [max_splits, feature_dim, num_buckets, 2*logits_dim] stats_summary = self._get_stats_summary_for_split_diagonal_hessian() stats_summary = self._append_zeros_for_default_bucket(stats_summary) l2 = 0.1 - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=l2, tree_complexity=0.0, @@ -1766,6 +1950,7 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): self.assertAllEqual([1, 2], node_ids) self.assertAllClose([0.475669, 1.009791], gains) self.assertAllEqual([1, 1], thresholds) + self.assertAllEqual([14, 14], feature_ids) self.assertAllEqual([0, 1], feature_dimensions) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. @@ -1777,16 +1962,18 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): def testCalculateBestFeatureSplitsWithMissingValuesL2(self): """Testing best split calculation inith L2 regularization.""" + candidate_feature_ids = [14] node_id_range = [1, 3] # node 1 through 2 will be processed. # [max_splits, feature_dim, num_buckets, 2*logits_dim] stats_summary = self._get_stats_summary_for_split_diagonal_hessian() l2 = 0.1 - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=l2, tree_complexity=0.0, @@ -1796,6 +1983,7 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): self.assertAllEqual([1, 2], node_ids) self.assertAllClose([0.475669, 3.467833], gains) self.assertAllEqual([1, 0], thresholds) + self.assertAllEqual([14, 14], feature_ids) self.assertAllEqual([0, 1], feature_dimensions) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. @@ -1808,15 +1996,17 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): def testCalculateBestFeatureSplitsWithMinNodeWeight(self): """Testing best split calculation with min_node_weight.""" + candidate_feature_ids = [14] node_id_range = [1, 3] # node 1 through 2 will be processed. 
# [max_splits, feature_dim, num_buckets, 2*logits_dim] stats_summary = self._get_stats_summary_for_split_diagonal_hessian() - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=0.0, tree_complexity=0.0, @@ -1827,6 +2017,7 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): self.assertAllEqual([1, 2], node_ids) self.assertAllClose([0.912981, 2.79444], gains) self.assertAllEqual([0, 1], thresholds) + self.assertAllEqual([14, 14], feature_ids) self.assertAllEqual([0, 1], feature_dimensions) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. @@ -1838,17 +2029,19 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): def testCalculateBestFeatureSplitsWithTreeComplexity(self): """Testing best split calculation with tree complexity.""" + candidate_feature_ids = [14] node_id_range = [1, 3] # node 1 through 2 will be processed. # [max_splits, feature_dim, num_buckets, 2*logits_dim] stats_summary = self._get_stats_summary_for_split_diagonal_hessian() l2 = 0.1 tree_complexity = 3. - (node_ids, gains, feature_dimensions, thresholds, left_node_contribs, - right_node_contribs, split_types) = self.evaluate( - boosted_trees_ops.calculate_best_feature_split( - node_id_range, - stats_summary, + (node_ids, gains, feature_ids, feature_dimensions, thresholds, + left_node_contribs, right_node_contribs, split_types) = self.evaluate( + boosted_trees_ops.calculate_best_feature_split_v2( + node_id_range, [stats_summary], + split_types=['inequality'], + candidate_feature_ids=candidate_feature_ids, l1=0.0, l2=l2, tree_complexity=tree_complexity, @@ -1860,6 +2053,7 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): # L2 test result, but subtracted by tree_complexity. self.assertAllClose([-2.524331, 0.467833], gains) self.assertAllEqual([1, 0], thresholds) + self.assertAllEqual([14, 14], feature_ids) self.assertAllEqual([0, 1], feature_dimensions) # The left node contrib will be later added to the previous node value to # make the left node value, and the same for right node contrib. @@ -1872,16 +2066,18 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest): def testCalculateBestFeatureSplitsWithMinNodeNoSplitOnFeaturePossible(self): """Test when parent node hessian doesn't meet min node weight.""" + candidate_feature_ids = [14] node_id_range = [1, 3] # node 1 through 2 will be processed. 
     # [max_splits, feature_dim, num_buckets, 2*logits_dim]
     stats_summary = self._get_stats_summary_for_split_diagonal_hessian()
     min_node_weight = 0.8
-    (node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
-     right_node_contribs, split_types) = self.evaluate(
-        boosted_trees_ops.calculate_best_feature_split(
-            node_id_range,
-            stats_summary,
+    (node_ids, gains, feature_ids, feature_dimensions, thresholds,
+     left_node_contribs, right_node_contribs, split_types) = self.evaluate(
+        boosted_trees_ops.calculate_best_feature_split_v2(
+            node_id_range, [stats_summary],
+            split_types=['inequality'],
+            candidate_feature_ids=candidate_feature_ids,
             l1=0.0,
             l2=0.0,
             tree_complexity=0.0,
@@ -1892,6 +2088,7 @@ class BestMultiDimFeatureSplitMultiClass(StatsOpsTest):
     self.assertAllEqual([2], node_ids)
     self.assertAllClose([2.79444], gains)
     self.assertAllEqual([1], thresholds)
+    self.assertAllEqual([14], feature_ids)
     self.assertAllEqual([1], feature_dimensions)
     # The left node contrib will be later added to the previous node value to
     # make the left node value, and the same for right node contrib.
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index 844b428a396..354180f8484 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import resources
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_aggregate_stats
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split as calculate_best_feature_split
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split_v2 as calculate_best_feature_split_v2
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
 from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 2441232462d..e4bd8c56389 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -492,6 +492,10 @@ tf_module {
     name: "BoostedTreesCalculateBestFeatureSplit"
     argspec: "args=[\'node_id_range\', \'stats_summary\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'split_type\', \'name\'], varargs=None, keywords=None, defaults=[\'inequality\', \'None\'], "
   }
+  member_method {
+    name: "BoostedTreesCalculateBestFeatureSplitV2"
+    argspec: "args=[\'node_id_range\', \'stats_summaries_list\', \'split_types\', \'candidate_feature_ids\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "BoostedTreesCalculateBestGainsPerFeature"
     argspec: "args=[\'node_id_range\', \'stats_summary_list\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'max_splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 2441232462d..e4bd8c56389 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -492,6 +492,10 @@ tf_module {
     name: "BoostedTreesCalculateBestFeatureSplit"
     argspec: "args=[\'node_id_range\', \'stats_summary\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'split_type\', \'name\'], varargs=None, keywords=None, defaults=[\'inequality\', \'None\'], "
   }
+  member_method {
+    name: "BoostedTreesCalculateBestFeatureSplitV2"
+    argspec: "args=[\'node_id_range\', \'stats_summaries_list\', \'split_types\', \'candidate_feature_ids\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "BoostedTreesCalculateBestGainsPerFeature"
     argspec: "args=[\'node_id_range\', \'stats_summary_list\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'max_splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
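Usage sketch (illustrative only, not part of the change above): the V2 op exposed through the new calculate_best_feature_split_v2 wrapper takes a Python list of per-feature stats summaries together with parallel split_types and candidate_feature_ids, and returns an extra feature_ids output naming the winning feature for each node. The snippet below only mirrors the call pattern used in the converted tests; the tensor shapes, values, and the feature id 4 are made-up examples, and it assumes eager execution.

    import numpy as np
    from tensorflow.python.ops import boosted_trees_ops

    # Process node 1 only; node_id_range is [inclusive, exclusive).
    node_id_range = [1, 2]
    # One stats summary per candidate feature, shaped
    # [max_splits, feature_dim, num_buckets + 1, logits_dim + hessian_dim];
    # the extra last bucket holds the stats for missing/default values.
    stats_summary = np.zeros([3, 1, 5, 2], dtype=np.float32)
    stats_summary[1, 0, 0] = [0.3, 0.5]   # illustrative grad/hessian, bucket 0 of node 1
    stats_summary[1, 0, 1] = [-0.4, 0.7]  # illustrative grad/hessian, bucket 1 of node 1

    (node_ids, gains, feature_ids, feature_dimensions, thresholds,
     left_node_contribs, right_node_contribs, split_types) = (
         boosted_trees_ops.calculate_best_feature_split_v2(
             node_id_range,
             [stats_summary],              # list: one entry per candidate feature
             split_types=['inequality'],   # per-feature split type
             candidate_feature_ids=[4],    # arbitrary id, echoed back in feature_ids
             l1=0.0,
             l2=0.0,
             tree_complexity=0.0,
             min_node_weight=0.0,
             logits_dimension=1))
    # feature_ids pairs each returned node_id with the global id of its best feature.

Passing one summary per feature, instead of the single stacked tensor the V1 op took, is what lets each feature carry its own split type and candidate id in a single call.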