Best feature split operates on a list of tensors.

PiperOrigin-RevId: 287242453
Change-Id: I7b2abe5d922a8000f96f92f81f074553b4e205e7
This commit is contained in:
A. Unique TensorFlower 2019-12-26 17:23:47 -08:00 committed by TensorFlower Gardener
parent 621679cbc0
commit a16a650ff0
7 changed files with 1057 additions and 273 deletions

View File

@ -0,0 +1,124 @@
op {
graph_op_name: "BoostedTreesCalculateBestFeatureSplitV2"
visibility: HIDDEN
in_arg {
name: "node_id_range"
description: <<END
A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive).
END
}
in_arg {
name: "stats_summaries_list"
description: <<END
A list of Rank 4 tensor (#shape=[max_splits, feature_dims, bucket, stats_dims]) for accumulated stats summary (gradient/hessian) per node, per dimension, per buckets for each feature.
The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
END
}
in_arg {
name: "split_types"
description: <<END
A Rank 1 tensor indicating if this Op should perform inequality split or equality split per feature.
END
}
in_arg {
name: "candidate_feature_ids"
description: <<END
Rank 1 tensor with ids for each feature. This is the real id of the feature.
END
}
in_arg {
name: "l1"
description: <<END
l1 regularization factor on leaf weights, per instance based.
END
}
in_arg {
name: "l2"
description: <<END
l2 regularization factor on leaf weights, per instance based.
END
}
in_arg {
name: "tree_complexity"
description: <<END
adjustment to the gain, per leaf based.
END
}
in_arg {
name: "min_node_weight"
description: <<END
mininum avg of hessians in a node before required for the node to be considered for splitting.
END
}
out_arg {
name: "node_ids"
description: <<END
A Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.
END
}
out_arg {
name: "gains"
description: <<END
A Rank 1 tensor indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.
END
}
out_arg {
name: "feature_ids"
description: <<END
A Rank 1 tensors indicating the best feature id for each node. See above for details like shapes and sizes.
END
}
out_arg {
name: "feature_dimensions"
description: <<END
A Rank 1 tensors indicating the best feature dimension for each feature to split for certain nodes if the feature is multi-dimension. See above for details like shapes and sizes.
END
}
out_arg {
name: "thresholds"
description: <<END
A Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.
END
}
out_arg {
name: "left_node_contribs"
description: <<END
A Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.
END
}
out_arg {
name: "right_node_contribs"
description: <<END
A Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node.
END
}
out_arg {
name: "split_with_default_directions"
description: <<END
A Rank 1 tensors indicating the which direction to go if data is missing. See above for details like shapes and sizes.
Inequality with default left returns 0, inequality with default right returns 1, equality with default right returns 2.
END
}
attr {
name: "num_features"
description: <<END
inferred from the size of `stats_summary_list`; the number of total features.
END
}
attr {
name: "logits_dimension"
description: <<END
The dimension of logit, i.e., number of classes.
END
}
summary: "Calculates gains for each feature and returns the best possible split information for each node. However, if no split is found, then no split information is returned for that node."
description: <<END
The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
The output shapes are compatible in a way that the first dimension of all tensors are the same and equal to the number of possible split nodes for each feature.
END
}

View File

@ -34,7 +34,10 @@ using MatrixMap = Eigen::Map<Matrix>;
using ConstVectorMap = Eigen::Map<const Eigen::VectorXf>;
using VectorMap = Eigen::Map<Eigen::VectorXf>;
// V1 Op. Deprecated. BoostedTreesCalculateBestFeatureSplitOp is V2.
constexpr char kInequalitySplit[] = "inequality";
constexpr char kEqualitySplit[] = "equality";
// V1 Op. Deprecated. BoostedTreesCalculateBestFeatureSplitOpV2 is V2.
class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
public:
explicit BoostedTreesCalculateBestGainsPerFeatureOp(
@ -227,7 +230,7 @@ REGISTER_KERNEL_BUILDER(
Name("BoostedTreesCalculateBestGainsPerFeature").Device(DEVICE_CPU),
BoostedTreesCalculateBestGainsPerFeatureOp);
// V2 Op.
// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitOpV2.
class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
public:
explicit BoostedTreesCalculateBestFeatureSplitOp(
@ -545,11 +548,394 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
string split_type_;
};
// v2 op that supports multi-class.
// Deprecated op. Use BoostedTreesCalculateBestFeatureSplitOpV2.
REGISTER_KERNEL_BUILDER(
Name("BoostedTreesCalculateBestFeatureSplit").Device(DEVICE_CPU),
BoostedTreesCalculateBestFeatureSplitOp);
// V2 Op.
class BoostedTreesCalculateBestFeatureSplitV2 : public OpKernel {
public:
explicit BoostedTreesCalculateBestFeatureSplitV2(
OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
}
void Compute(OpKernelContext* const context) override {
// node_id_range
const Tensor* node_id_range_t;
OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
const auto node_id_range = node_id_range_t->vec<int32>();
const int32 node_id_first = node_id_range(0); // Inclusive.
const int32 node_id_last = node_id_range(1); // Exclusive.
// Get stats_summaries_list.
OpInputList stats_summaries_list;
OP_REQUIRES_OK(context, context->input_list("stats_summaries_list",
&stats_summaries_list));
// Infer dimensions of a stats_summary.
DCHECK_GT(stats_summaries_list.size(), 0);
const int32 feature_dims = stats_summaries_list[0].dim_size(1);
// The last bucket is for default/missing value.
const int32 num_buckets = stats_summaries_list[0].dim_size(2) - 1;
const int32 logits_dim = logits_dim_;
const int32 hessian_dim = stats_summaries_list[0].dim_size(3) - logits_dim;
DCHECK_GT(hessian_dim, 0);
DCHECK_LE(hessian_dim, logits_dim * logits_dim);
// Vector of stats_summaries; each element is stats for feature of shape
// [max_splits, feature_dim, num_buckets, logits_dim + hessian_dim].
std::vector<TTypes<float, 4>::ConstTensor> stats_summaries;
DCHECK_EQ(stats_summaries_list.size(), num_features_);
stats_summaries.reserve(num_features_);
for (const auto& tensor : stats_summaries_list) {
stats_summaries.emplace_back(tensor.tensor<float, 4>());
}
// Split types.
const Tensor* split_types_t;
OP_REQUIRES_OK(context, context->input("split_types", &split_types_t));
const auto split_types = split_types_t->vec<string>();
DCHECK_EQ(split_types.size(), num_features_);
// Validate.
for (int i = 0; i < num_features_; ++i) {
if (!(split_types(i) == kInequalitySplit ||
split_types(i) == kEqualitySplit)) {
OP_REQUIRES_OK(
context,
errors::Aborted(
"Operation received an exception: Incorrect split type"));
}
}
// Feature ids.
const Tensor* candidate_feature_ids_t;
OP_REQUIRES_OK(context, context->input("candidate_feature_ids",
&candidate_feature_ids_t));
const auto candidate_feature_ids = candidate_feature_ids_t->vec<int32>();
DCHECK_EQ(candidate_feature_ids.size(), num_features_);
// L1, L2, tree_complexity, min_node_weight.
const Tensor* l1_t;
OP_REQUIRES_OK(context, context->input("l1", &l1_t));
const auto l1 = l1_t->scalar<float>()();
DCHECK_GE(l1, 0);
if (logits_dim_ > 1) {
// Multi-class L1 regularization not supported yet.
DCHECK_EQ(l1, 0);
}
const Tensor* l2_t;
OP_REQUIRES_OK(context, context->input("l2", &l2_t));
const auto l2 = l2_t->scalar<float>()();
DCHECK_GE(l2, 0);
const Tensor* tree_complexity_t;
OP_REQUIRES_OK(context,
context->input("tree_complexity", &tree_complexity_t));
const auto tree_complexity = tree_complexity_t->scalar<float>()();
const Tensor* min_node_weight_t;
OP_REQUIRES_OK(context,
context->input("min_node_weight", &min_node_weight_t));
const auto min_node_weight = min_node_weight_t->scalar<float>()();
std::vector<int32> output_node_ids;
std::vector<float> output_gains;
std::vector<int32> output_feature_ids;
std::vector<int32> output_feature_dimensions;
std::vector<int32> output_thresholds;
std::vector<Eigen::VectorXf> output_left_node_contribs;
std::vector<Eigen::VectorXf> output_right_node_contribs;
std::vector<string> output_split_types;
// TODO(tanzheny) parallelize the computation.
// Iterate each node and find the best gain per node.
float parent_gain;
for (int32 node_id = node_id_first; node_id < node_id_last; ++node_id) {
float best_gain = std::numeric_limits<float>::lowest();
int32 best_bucket;
int32 best_f_id;
int32 best_f_dim;
string best_split_type;
Eigen::VectorXf best_contrib_for_left(logits_dim);
Eigen::VectorXf best_contrib_for_right(logits_dim);
// Sum of gradient and hessian. Compute parent gain using first feature.
ConstMatrixMap stats_mat(&stats_summaries[0](node_id, 0, 0, 0),
num_buckets + 1, // Including default bucket.
logits_dim + hessian_dim);
const Eigen::VectorXf total_grad =
stats_mat.leftCols(logits_dim).colwise().sum();
const Eigen::VectorXf total_hess =
stats_mat.rightCols(hessian_dim).colwise().sum();
if (total_hess.norm() < min_node_weight) {
continue;
}
Eigen::VectorXf unused(logits_dim);
CalculateWeightsAndGains(total_grad, total_hess, l1, l2, &unused,
&parent_gain);
for (int f_idx = 0; f_idx < num_features_; ++f_idx) {
const string split_type = split_types(f_idx);
TTypes<float, 4>::ConstTensor stats_summary = stats_summaries[f_idx];
float f_best_gain = std::numeric_limits<float>::lowest();
int32 f_best_bucket;
int32 f_best_f_dim;
string f_best_split_type;
Eigen::VectorXf f_best_contrib_for_left(logits_dim);
Eigen::VectorXf f_best_contrib_for_right(logits_dim);
if (split_type == kInequalitySplit) {
CalculateBestInequalitySplit(
stats_summary, node_id, feature_dims, logits_dim, hessian_dim,
num_buckets, min_node_weight, l1, l2, &f_best_gain,
&f_best_bucket, &f_best_f_dim, &f_best_split_type,
&f_best_contrib_for_left, &f_best_contrib_for_right);
} else {
CalculateBestEqualitySplit(
stats_summary, total_grad, total_hess, node_id, feature_dims,
logits_dim, hessian_dim, num_buckets, l1, l2, &f_best_gain,
&f_best_bucket, &f_best_f_dim, &f_best_split_type,
&f_best_contrib_for_left, &f_best_contrib_for_right);
}
if (f_best_gain > best_gain) {
best_gain = f_best_gain;
best_f_id = candidate_feature_ids(f_idx);
best_f_dim = f_best_f_dim;
best_split_type = f_best_split_type;
best_bucket = f_best_bucket;
best_contrib_for_left = f_best_contrib_for_left;
best_contrib_for_right = f_best_contrib_for_right;
}
} // For feature id.
if (best_gain == std::numeric_limits<float>::lowest()) {
// Do not add the node if no split is found.
continue;
}
output_node_ids.push_back(node_id);
// Remove the parent gain for the parent node.
output_gains.push_back(best_gain - parent_gain);
output_feature_ids.push_back(best_f_id);
output_feature_dimensions.push_back(best_f_dim);
// Default direction is fixed for dense splits.
// TODO(tanzheny) account for default values.
output_split_types.push_back(best_split_type);
output_thresholds.push_back(best_bucket);
output_left_node_contribs.push_back(best_contrib_for_left);
output_right_node_contribs.push_back(best_contrib_for_right);
} // for node id.
const int num_nodes = output_node_ids.size();
// output_node_ids
Tensor* output_node_ids_t = nullptr;
OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
&output_node_ids_t));
auto output_node_ids_vec = output_node_ids_t->vec<int32>();
// output_gains
Tensor* output_gains_t;
OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
&output_gains_t));
auto output_gains_vec = output_gains_t->vec<float>();
// output_feature_ids
Tensor* output_features_ids_t;
OP_REQUIRES_OK(context, context->allocate_output("feature_ids", {num_nodes},
&output_features_ids_t));
auto output_features_vec = output_features_ids_t->vec<int32>();
// output_feature_dimensions
Tensor* output_feature_dimension_t;
OP_REQUIRES_OK(context,
context->allocate_output("feature_dimensions", {num_nodes},
&output_feature_dimension_t));
auto output_feature_dimensions_vec =
output_feature_dimension_t->vec<int32>();
// output_thresholds
Tensor* output_thresholds_t;
OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
&output_thresholds_t));
auto output_thresholds_vec = output_thresholds_t->vec<int32>();
// output_left_node_contribs
Tensor* output_left_node_contribs_t;
OP_REQUIRES_OK(context, context->allocate_output(
"left_node_contribs", {num_nodes, logits_dim},
&output_left_node_contribs_t));
auto output_left_node_contribs_matrix =
output_left_node_contribs_t->matrix<float>();
// output_right_node_contribs
Tensor* output_right_node_contribs_t;
OP_REQUIRES_OK(context, context->allocate_output(
"right_node_contribs", {num_nodes, logits_dim},
&output_right_node_contribs_t));
auto output_right_node_contribs_matrix =
output_right_node_contribs_t->matrix<float>();
// split type
Tensor* output_split_types_t;
OP_REQUIRES_OK(
context, context->allocate_output("split_with_default_directions",
{num_nodes}, &output_split_types_t));
auto output_split_types_vec = output_split_types_t->vec<tstring>();
// Sets output tensors from vectors.
for (int i = 0; i < num_nodes; ++i) {
output_node_ids_vec(i) = output_node_ids[i];
output_features_vec(i) = output_feature_ids[i];
// Adjust the gains to penalize by tree complexity.
output_gains_vec(i) = output_gains[i] - tree_complexity;
output_feature_dimensions_vec(i) = output_feature_dimensions[i];
output_thresholds_vec(i) = output_thresholds[i];
for (int j = 0; j < logits_dim; ++j) {
output_left_node_contribs_matrix(i, j) =
output_left_node_contribs[i][j];
output_right_node_contribs_matrix(i, j) =
output_right_node_contribs[i][j];
}
output_split_types_vec(i) = output_split_types[i];
}
}
private:
// TODO(crawles): Simplify inequality path just like equality b/138329196
// Currently this is not simplify-able due to numerical instability in math
// i.e. gain = -g.transpose() * hessian_and_reg.colPivHouseholderQr().solve(g)
// It caused gain to be Inf when g is approaching 0 but not exactly 0 while
// there is no regularization.
// Calculate the best inequality split per node.
void CalculateBestInequalitySplit(
TTypes<float, 4>::ConstTensor stats_summary, const int32 node_id,
const int32 feature_dims, const int32 logits_dim, const int32 hessian_dim,
const int32 num_buckets, const float min_node_weight, const float l1,
const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
Eigen::VectorXf* best_contrib_for_right) {
std::vector<Eigen::VectorXf> cum_grad;
std::vector<Eigen::VectorXf> cum_hess;
// get all cumulative gradients including default bucket.
cum_grad.reserve(num_buckets);
cum_hess.reserve(num_buckets);
for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
ConstVectorMap default_stats_vec(
&stats_summary(node_id, f_dim, num_buckets, 0),
logits_dim + hessian_dim);
Eigen::VectorXf missing_bucket_grad = default_stats_vec.head(logits_dim);
Eigen::VectorXf missing_bucket_hess = default_stats_vec.tail(hessian_dim);
cum_grad.clear();
cum_hess.clear();
Eigen::VectorXf total_grad = Eigen::VectorXf::Zero(logits_dim);
Eigen::VectorXf total_hess = Eigen::VectorXf::Zero(hessian_dim);
// sum all the gradients including default bucket.
for (int bucket = 0; bucket <= num_buckets; ++bucket) {
for (int i = 0; i < logits_dim; ++i) {
total_grad[i] += stats_summary(node_id, f_dim, bucket, i);
}
for (int i = 0; i < hessian_dim; ++i) {
// Full hessian.
total_hess[i] +=
stats_summary(node_id, f_dim, bucket, logits_dim + i);
}
if (bucket < num_buckets) {
cum_grad.push_back(total_grad);
cum_hess.push_back(total_hess);
}
}
const string kInequalityDefaultLeft =
boosted_trees::SplitTypeWithDefault_Name(
boosted_trees::INEQUALITY_DEFAULT_LEFT);
const string kInequalityDefaultRight =
boosted_trees::SplitTypeWithDefault_Name(
boosted_trees::INEQUALITY_DEFAULT_RIGHT);
// Iterate from left to right, excluding default bucket.
for (int bucket = 0; bucket < num_buckets; ++bucket) {
// default value goes to left node.
const Eigen::VectorXf total_left_grad =
cum_grad[bucket] + missing_bucket_grad;
const Eigen::VectorXf total_left_hess =
cum_hess[bucket] + missing_bucket_hess;
MaybeUpdateBestSplit(
total_left_grad, total_grad - total_left_grad, total_left_hess,
total_hess - total_left_hess, logits_dim, bucket, f_dim, l1, l2,
kInequalityDefaultLeft, best_gain, best_bucket, best_f_dim,
best_split_type, best_contrib_for_left, best_contrib_for_right);
// default value goes to right node.
MaybeUpdateBestSplit(
cum_grad[bucket], total_grad - cum_grad[bucket], cum_hess[bucket],
total_hess - cum_hess[bucket], logits_dim, bucket, f_dim, l1, l2,
kInequalityDefaultRight, best_gain, best_bucket, best_f_dim,
best_split_type, best_contrib_for_left, best_contrib_for_right);
} // for bucket
}
}
// Calculate the best equality split per node.
void CalculateBestEqualitySplit(
TTypes<float, 4>::ConstTensor stats_summary,
const Eigen::VectorXf& total_grad, const Eigen::VectorXf& total_hess,
const int32 node_id, const int32 feature_dims, const int32 logits_dim,
const int32 hessian_dim, const int32 num_buckets, const float l1,
const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
Eigen::VectorXf* best_contrib_for_right) {
const string kEqualityDefaultRight =
boosted_trees::SplitTypeWithDefault_Name(
boosted_trees::EQUALITY_DEFAULT_RIGHT);
for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
for (int bucket = 0; bucket < num_buckets; ++bucket) {
ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
logits_dim + hessian_dim);
Eigen::VectorXf curr_grad = stats_vec.head(logits_dim);
Eigen::VectorXf curr_hess = stats_vec.tail(hessian_dim);
MaybeUpdateBestSplit(curr_grad, total_grad - curr_grad, curr_hess,
total_hess - curr_hess, logits_dim, bucket, f_dim,
l1, l2, kEqualityDefaultRight, best_gain,
best_bucket, best_f_dim, best_split_type,
best_contrib_for_left, best_contrib_for_right);
}
}
}
void MaybeUpdateBestSplit(const Eigen::VectorXf& grad_for_left,
const Eigen::VectorXf& grad_for_right,
const Eigen::VectorXf& hess_for_left,
const Eigen::VectorXf& hess_for_right,
const int32 logits_dim, const int32 bucket,
const int32 f_dim, const float l1, const float l2,
const string split_type, float* best_gain,
int32* best_bucket, int32* best_f_dim,
string* best_split_type,
Eigen::VectorXf* best_contrib_for_left,
Eigen::VectorXf* best_contrib_for_right) {
// Left child.
Eigen::VectorXf contrib_for_left(logits_dim);
float gain_for_left;
CalculateWeightsAndGains(grad_for_left, hess_for_left, l1, l2,
&contrib_for_left, &gain_for_left);
Eigen::VectorXf contrib_for_right(logits_dim);
float gain_for_right;
CalculateWeightsAndGains(grad_for_right, hess_for_right, l1, l2,
&contrib_for_right, &gain_for_right);
if (GainIsLarger(gain_for_left + gain_for_right, *best_gain)) {
*best_gain = gain_for_left + gain_for_right;
*best_bucket = bucket;
*best_f_dim = f_dim;
*best_contrib_for_left = contrib_for_left;
*best_contrib_for_right = contrib_for_right;
*best_split_type = split_type;
}
}
int num_features_;
int logits_dim_;
};
// v2 op that supports multi-class.
REGISTER_KERNEL_BUILDER(
Name("BoostedTreesCalculateBestFeatureSplitV2").Device(DEVICE_CPU),
BoostedTreesCalculateBestFeatureSplitV2);
// Map from bucket id to vector of statistics.
typedef std::map<int32, std::vector<float>> BucketMap;
typedef BucketMap::iterator BucketMapIterator;

View File

@ -141,6 +141,74 @@ REGISTER_OP("BoostedTreesCalculateBestFeatureSplit")
return Status::OK();
});
REGISTER_OP("BoostedTreesCalculateBestFeatureSplitV2")
.Input("node_id_range: int32")
.Input("stats_summaries_list: num_features * float32")
.Input("split_types: string")
.Input("candidate_feature_ids: int32")
.Input("l1: float")
.Input("l2: float")
.Input("tree_complexity: float")
.Input("min_node_weight: float")
.Attr("num_features: int >= 1") // not passed but populated automatically.
.Attr("logits_dimension: int >= 1")
.Output("node_ids: int32")
.Output("gains: float32")
.Output("feature_ids: int32")
.Output("feature_dimensions: int32")
.Output("thresholds: int32")
.Output("left_node_contribs: float32")
.Output("right_node_contribs: float32")
.Output("split_with_default_directions: string")
.SetShapeFn([](shape_inference::InferenceContext* c) {
// Attributes.
int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
int logits_dimension;
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
// Inputs.
shape_inference::ShapeHandle unused_shape;
// node id range is rank 1 with 2 values.
shape_inference::ShapeHandle node_id_range_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
TF_RETURN_IF_ERROR(
c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
// Stats summary validation.
shape_inference::ShapeHandle summary_shape_base;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &summary_shape_base));
// All stats summary entries are of the same shape.
for (int i = 1; i < num_features; ++i) {
shape_inference::ShapeHandle summary_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + i), 4, &summary_shape));
TF_RETURN_IF_ERROR(
c->Merge(summary_shape_base, summary_shape, &unused_shape));
}
// Validate rank 1 split_types.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(1 + num_features), 1, &unused_shape));
// Validate rank 1 feature_ids.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(2 + num_features), 1, &unused_shape));
// Validate rank 0: l1, l2, tree_complexity, min_node_weight.
for (int i = 0; i < 4; ++i) {
TF_RETURN_IF_ERROR(
c->WithRank(c->input(3 + num_features + i), 0, &unused_shape));
}
// Output shapes.
ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
c->set_output(0, rank_1_output_shape);
c->set_output(1, rank_1_output_shape);
c->set_output(2, rank_1_output_shape);
c->set_output(3, rank_1_output_shape);
c->set_output(4, rank_1_output_shape);
ShapeHandle contribs_output_shape =
c->MakeShape({c->UnknownDim(), logits_dimension});
c->set_output(5, contribs_output_shape);
c->set_output(6, contribs_output_shape);
c->set_output(7, rank_1_output_shape);
return Status::OK();
});
REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit")
.Input("node_id_range: int32")
.Input("stats_summary_indices: int32")

View File

@ -27,6 +27,7 @@ from tensorflow.python.ops import resources
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_aggregate_stats
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split as calculate_best_feature_split
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_feature_split_v2 as calculate_best_feature_split_v2
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource

View File

@ -492,6 +492,10 @@ tf_module {
name: "BoostedTreesCalculateBestFeatureSplit"
argspec: "args=[\'node_id_range\', \'stats_summary\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'split_type\', \'name\'], varargs=None, keywords=None, defaults=[\'inequality\', \'None\'], "
}
member_method {
name: "BoostedTreesCalculateBestFeatureSplitV2"
argspec: "args=[\'node_id_range\', \'stats_summaries_list\', \'split_types\', \'candidate_feature_ids\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BoostedTreesCalculateBestGainsPerFeature"
argspec: "args=[\'node_id_range\', \'stats_summary_list\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'max_splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -492,6 +492,10 @@ tf_module {
name: "BoostedTreesCalculateBestFeatureSplit"
argspec: "args=[\'node_id_range\', \'stats_summary\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'split_type\', \'name\'], varargs=None, keywords=None, defaults=[\'inequality\', \'None\'], "
}
member_method {
name: "BoostedTreesCalculateBestFeatureSplitV2"
argspec: "args=[\'node_id_range\', \'stats_summaries_list\', \'split_types\', \'candidate_feature_ids\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "BoostedTreesCalculateBestGainsPerFeature"
argspec: "args=[\'node_id_range\', \'stats_summary_list\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'max_splits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "