Boosted Trees sparse Best Split Op.
PiperOrigin-RevId: 254117179
This commit is contained in:
parent
c174697c09
commit
8be8521ee5
@ -0,0 +1,118 @@
|
||||
op {
|
||||
graph_op_name: "BoostedTreesSparseCalculateBestFeatureSplit"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "node_id_range"
|
||||
description: <<END
|
||||
A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "stats_summary_indices"
|
||||
description: <<END
|
||||
A Rank 2 int64 tensor of dense shape [N, 4] (N specifies the number of non-zero values) for accumulated stats summary (gradient/hessian) per node per bucket for each feature. The second dimension contains node id, feature dimension, bucket id, and stats dim.
|
||||
stats dim is the sum of logits dimension and hessian dimension, hessian dimension can either be logits dimension if diagonal hessian is used, or logits dimension^2 if full hessian is used.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "stats_summary_values"
|
||||
description: <<END
|
||||
A Rank 1 float tensor of dense shape [N] (N specifies the number of non-zero values), which supplies the values for each element in summary_indices.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "stats_summary_shape"
|
||||
description: <<END
|
||||
A Rank 1 float tensor of dense shape [4], which specifies the dense shape of the sparse tensor, which is [num tree nodes, feature dimensions, num buckets, stats dim].
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l1"
|
||||
description: <<END
|
||||
l1 regularization factor on leaf weights, per instance based.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "l2"
|
||||
description: <<END
|
||||
l2 regularization factor on leaf weights, per instance based.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "tree_complexity"
|
||||
description: <<END
|
||||
adjustment to the gain, per leaf based.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "min_node_weight"
|
||||
description: <<END
|
||||
mininum avg of hessians in a node before required for the node to be considered for splitting.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "node_ids"
|
||||
description: <<END
|
||||
A Rank 1 tensor indicating possible node ids that can be split.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "gains"
|
||||
description: <<END
|
||||
A Rank 1 tensor indicating the best gains to split each node.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "feature_dimensions"
|
||||
description: <<END
|
||||
A Rank 1 tensor indicating the best feature dimension for each feature to split for each node.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "thresholds"
|
||||
description: <<END
|
||||
A Rank 1 tensor indicating the bucket id to compare with (as a threshold) for split in each node.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "left_node_contribs"
|
||||
description: <<END
|
||||
A Rank 2 tensor indicating the contribution of the left nodes when branching from parent nodes to the left direction by the given threshold for each feature.
|
||||
This value will be used to make the left node value by adding to the parent node value. Second dimension size is logits dimension.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "right_node_contribs"
|
||||
description: <<END
|
||||
A Rank 2 tensor, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "split_with_default_directions"
|
||||
description: <<END
|
||||
A Rank 1 tensor indicating which direction to go if data is missing.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "logits_dimension"
|
||||
description: <<END
|
||||
The dimension of logit, i.e., number of classes.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "split_type"
|
||||
description: <<END
|
||||
A string indicating if this Op should perform inequality split or equality split.
|
||||
END
|
||||
}
|
||||
summary: "Calculates gains for each feature and returns the best possible split information for the feature."
|
||||
description: <<END
|
||||
The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
|
||||
|
||||
It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
|
||||
|
||||
In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
|
||||
|
||||
The output shapes are compatible in a way that the first dimension of all tensors are the same and equal to the number of possible split nodes for each feature.
|
||||
END
|
||||
}
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
@ -25,6 +26,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
const char INEQUALITY_DEFAULT_LEFT[] = "inequality_default_left";
|
||||
const char INEQUALITY_DEFAULT_RIGHT[] = "inequality_default_right";
|
||||
|
||||
// V1 Op. Deprecated. BoostedTreesCalculateBestFeatureSplitOp is V2.
|
||||
class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
|
||||
@ -439,6 +441,306 @@ REGISTER_KERNEL_BUILDER(
|
||||
Name("BoostedTreesCalculateBestFeatureSplit").Device(DEVICE_CPU),
|
||||
BoostedTreesCalculateBestFeatureSplitOp);
|
||||
|
||||
// Map from bucket id to vector of statistics.
|
||||
typedef std::map<int32, std::vector<float>> BucketMap;
|
||||
typedef BucketMap::iterator BucketMapIterator;
|
||||
// Map from feature dimension to BucketMap.
|
||||
typedef std::map<int32, BucketMap> FeatureMap;
|
||||
typedef FeatureMap::iterator FeatureMapIterator;
|
||||
|
||||
class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel {
|
||||
public:
|
||||
explicit BoostedTreesSparseCalculateBestFeatureSplitOp(
|
||||
OpKernelConstruction* const context)
|
||||
: OpKernel(context) {
|
||||
// TODO(crawles): Using logits_dim_ for multi-class split.
|
||||
OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
|
||||
// TODO(tanzheny): Using this for equality split.
|
||||
OP_REQUIRES_OK(context, context->GetAttr("split_type", &split_type_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* const context) override {
|
||||
// node_id_range
|
||||
const Tensor* node_id_range_t;
|
||||
OP_REQUIRES_OK(context, context->input("node_id_range", &node_id_range_t));
|
||||
const auto node_id_range = node_id_range_t->vec<int32>();
|
||||
const int32 node_id_first = node_id_range(0); // inclusive
|
||||
const int32 node_id_last = node_id_range(1); // exclusive
|
||||
|
||||
const Tensor* stats_summary_indices_t;
|
||||
OP_REQUIRES_OK(context, context->input("stats_summary_indices",
|
||||
&stats_summary_indices_t));
|
||||
const auto stats_summary_indices = stats_summary_indices_t->matrix<int32>();
|
||||
const int32 num_sparse_entries = stats_summary_indices_t->dim_size(0);
|
||||
|
||||
const Tensor* stats_summary_values_t;
|
||||
OP_REQUIRES_OK(context, context->input("stats_summary_values",
|
||||
&stats_summary_values_t));
|
||||
const auto stats_summary_values = stats_summary_values_t->vec<float>();
|
||||
|
||||
const Tensor* stats_summary_shape_t;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->input("stats_summary_shape", &stats_summary_shape_t));
|
||||
const auto stats_summary_shape = stats_summary_shape_t->vec<int32>();
|
||||
const int32 num_buckets = stats_summary_shape(2) - 1;
|
||||
const int32 stats_dims = stats_summary_shape(3);
|
||||
|
||||
const Tensor* l1_t;
|
||||
OP_REQUIRES_OK(context, context->input("l1", &l1_t));
|
||||
const auto l1 = l1_t->scalar<float>()();
|
||||
|
||||
const Tensor* l2_t;
|
||||
OP_REQUIRES_OK(context, context->input("l2", &l2_t));
|
||||
const auto l2 = l2_t->scalar<float>()();
|
||||
|
||||
const Tensor* tree_complexity_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("tree_complexity", &tree_complexity_t));
|
||||
const auto tree_complexity = tree_complexity_t->scalar<float>()();
|
||||
|
||||
const Tensor* min_node_weight_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input("min_node_weight", &min_node_weight_t));
|
||||
const auto min_node_weight = min_node_weight_t->scalar<float>()();
|
||||
|
||||
std::vector<int32> output_node_ids;
|
||||
std::vector<float> output_gains;
|
||||
std::vector<int32> output_feature_dimensions;
|
||||
std::vector<int32> output_thresholds;
|
||||
std::vector<float> output_left_node_contribs;
|
||||
std::vector<float> output_right_node_contribs;
|
||||
std::vector<string> output_split_types;
|
||||
|
||||
FeatureMap f_map;
|
||||
|
||||
int32 previous_node_id = -1;
|
||||
for (int idx = 0; idx < num_sparse_entries; ++idx) {
|
||||
int32 node_id = stats_summary_indices(idx, 0);
|
||||
if (node_id != previous_node_id) {
|
||||
process_node(f_map, &output_node_ids, &output_gains,
|
||||
&output_feature_dimensions, &output_thresholds,
|
||||
&output_left_node_contribs, &output_right_node_contribs,
|
||||
&output_split_types, previous_node_id, min_node_weight, l1,
|
||||
l2, num_buckets);
|
||||
f_map.clear();
|
||||
}
|
||||
previous_node_id = node_id;
|
||||
DCHECK_LE(node_id_first, node_id);
|
||||
DCHECK_LT(node_id, node_id_last);
|
||||
const int32 feature_dim = stats_summary_indices(idx, 1);
|
||||
const int32 bucket_id = stats_summary_indices(idx, 2);
|
||||
const int32 stat_dim = stats_summary_indices(idx, 3);
|
||||
std::pair<FeatureMapIterator, bool> const& f_insert_result = f_map.insert(
|
||||
FeatureMapIterator::value_type(feature_dim, BucketMap()));
|
||||
auto& b_map = f_insert_result.first->second;
|
||||
std::pair<BucketMapIterator, bool> const& b_insert_result =
|
||||
b_map.insert(BucketMapIterator::value_type(
|
||||
bucket_id, std::vector<float>(stats_dims)));
|
||||
auto& stats = b_insert_result.first->second;
|
||||
stats[stat_dim] = stats_summary_values(idx);
|
||||
} // for node_id
|
||||
// process the last node id
|
||||
process_node(f_map, &output_node_ids, &output_gains,
|
||||
&output_feature_dimensions, &output_thresholds,
|
||||
&output_left_node_contribs, &output_right_node_contribs,
|
||||
&output_split_types, previous_node_id, min_node_weight, l1, l2,
|
||||
num_buckets);
|
||||
|
||||
const int num_nodes = output_node_ids.size();
|
||||
// output_node_ids
|
||||
Tensor* output_node_ids_t = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output("node_ids", {num_nodes},
|
||||
&output_node_ids_t));
|
||||
auto output_node_ids_vec = output_node_ids_t->vec<int32>();
|
||||
|
||||
// output_gains
|
||||
Tensor* output_gains_t;
|
||||
OP_REQUIRES_OK(context, context->allocate_output("gains", {num_nodes},
|
||||
&output_gains_t));
|
||||
auto output_gains_vec = output_gains_t->vec<float>();
|
||||
|
||||
// output_feature_dimensions
|
||||
Tensor* output_feature_dimension_t;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output("feature_dimensions", {num_nodes},
|
||||
&output_feature_dimension_t));
|
||||
auto output_feature_dimensions_vec =
|
||||
output_feature_dimension_t->vec<int32>();
|
||||
|
||||
// output_thresholds
|
||||
Tensor* output_thresholds_t;
|
||||
OP_REQUIRES_OK(context, context->allocate_output("thresholds", {num_nodes},
|
||||
&output_thresholds_t));
|
||||
auto output_thresholds_vec = output_thresholds_t->vec<int32>();
|
||||
|
||||
// output_left_node_contribs
|
||||
Tensor* output_left_node_contribs_t;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output("left_node_contribs", {num_nodes, 1},
|
||||
&output_left_node_contribs_t));
|
||||
auto output_left_node_contribs_matrix =
|
||||
output_left_node_contribs_t->matrix<float>();
|
||||
|
||||
// output_right_node_contribs
|
||||
Tensor* output_right_node_contribs_t;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output("right_node_contribs", {num_nodes, 1},
|
||||
&output_right_node_contribs_t));
|
||||
auto output_right_node_contribs_matrix =
|
||||
output_right_node_contribs_t->matrix<float>();
|
||||
|
||||
// split type
|
||||
Tensor* output_split_types_t;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output("split_with_default_directions",
|
||||
{num_nodes}, &output_split_types_t));
|
||||
auto output_split_types_vec = output_split_types_t->vec<string>();
|
||||
|
||||
// Sets output tensors from vectors.
|
||||
for (int i = 0; i < num_nodes; ++i) {
|
||||
output_node_ids_vec(i) = output_node_ids[i];
|
||||
// Adjust the gains to penalize by tree complexity.
|
||||
output_gains_vec(i) = output_gains[i] - tree_complexity;
|
||||
output_feature_dimensions_vec(i) = output_feature_dimensions[i];
|
||||
output_thresholds_vec(i) = output_thresholds[i];
|
||||
// TODO(crawles): change this for multi-class.
|
||||
output_left_node_contribs_matrix(i, 0) = output_left_node_contribs[i];
|
||||
output_right_node_contribs_matrix(i, 0) = output_right_node_contribs[i];
|
||||
output_split_types_vec(i) = output_split_types[i];
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
void process_node(const FeatureMap& f_map,
|
||||
std::vector<int32>* output_node_ids,
|
||||
std::vector<float>* output_gains,
|
||||
std::vector<int32>* output_feature_dimensions,
|
||||
std::vector<int32>* output_thresholds,
|
||||
std::vector<float>* output_left_node_contribs,
|
||||
std::vector<float>* output_right_node_contribs,
|
||||
std::vector<string>* output_split_types,
|
||||
const int32 node_id, const float min_node_weight,
|
||||
const float l1, const float l2, const int32 num_buckets) {
|
||||
float parent_gain;
|
||||
Eigen::VectorXf unused(logits_dim_);
|
||||
Eigen::MatrixXf identity;
|
||||
identity.setIdentity(1, 1);
|
||||
|
||||
// start processing for previous node id.
|
||||
float best_gain = std::numeric_limits<float>::lowest();
|
||||
float best_bucket = 0;
|
||||
float best_f_dim = 0;
|
||||
string best_split_type = INEQUALITY_DEFAULT_LEFT;
|
||||
float best_contrib_for_left = 0.0;
|
||||
float best_contrib_for_right = 0.0;
|
||||
// the sum of gradients including default bucket.
|
||||
float total_grad = 0;
|
||||
// the sum of hessians including default bucket.
|
||||
float total_hess = 0;
|
||||
|
||||
for (auto f_iter = f_map.begin(); f_iter != f_map.end(); ++f_iter) {
|
||||
const int32 feature_dim = f_iter->first;
|
||||
const auto buckets_to_stats_map = f_iter->second;
|
||||
|
||||
// The very last bucket contains stats for missing values.
|
||||
// TODO(crawles): use vector for multi-class.
|
||||
const float default_grad =
|
||||
(buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end()
|
||||
? 0
|
||||
: buckets_to_stats_map.at(num_buckets)[0]);
|
||||
const float default_hess =
|
||||
(buckets_to_stats_map.find(num_buckets) == buckets_to_stats_map.end()
|
||||
? 0
|
||||
: buckets_to_stats_map.at(num_buckets)[1]);
|
||||
|
||||
if (f_iter == f_map.begin()) {
|
||||
// first get the sum of grads, including default bucket.
|
||||
for (auto b_iter = buckets_to_stats_map.begin();
|
||||
b_iter != buckets_to_stats_map.end(); ++b_iter) {
|
||||
total_grad += b_iter->second[0];
|
||||
total_hess += b_iter->second[1];
|
||||
}
|
||||
if (total_hess < min_node_weight) {
|
||||
// Do not split the node because not enough avg hessian.
|
||||
break;
|
||||
}
|
||||
CalculateWeightsAndGains(total_grad * identity, total_hess * identity,
|
||||
l1, l2, &unused, &parent_gain);
|
||||
}
|
||||
|
||||
float total_left_grad = 0;
|
||||
float total_left_hess = 0;
|
||||
for (auto b_iter = buckets_to_stats_map.begin();
|
||||
b_iter != buckets_to_stats_map.end(); ++b_iter) {
|
||||
const int32 bucket_id = b_iter->first;
|
||||
// total_left_stats should exclude stats from default bucket.
|
||||
if (bucket_id == num_buckets) {
|
||||
break;
|
||||
}
|
||||
// TODO(crawles): vector for multi-class.
|
||||
total_left_grad += b_iter->second[0];
|
||||
total_left_hess += b_iter->second[1];
|
||||
// From left to right, default right.
|
||||
// Left child.
|
||||
Eigen::VectorXf contrib_for_left(1);
|
||||
float gain_for_left;
|
||||
CalculateWeightsAndGains(total_left_grad * identity,
|
||||
total_left_hess * identity, l1, l2,
|
||||
&contrib_for_left, &gain_for_left);
|
||||
// Right child.
|
||||
Eigen::VectorXf contrib_for_right(1);
|
||||
float gain_for_right;
|
||||
CalculateWeightsAndGains((total_grad - total_left_grad) * identity,
|
||||
(total_hess - total_left_hess) * identity, l1,
|
||||
l2, &contrib_for_right, &gain_for_right);
|
||||
if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
|
||||
best_gain = gain_for_left + gain_for_right;
|
||||
best_bucket = bucket_id;
|
||||
best_f_dim = feature_dim;
|
||||
best_split_type = INEQUALITY_DEFAULT_RIGHT;
|
||||
best_contrib_for_left = contrib_for_left[0];
|
||||
best_contrib_for_right = contrib_for_right[0];
|
||||
}
|
||||
|
||||
// From right to left, default left.
|
||||
CalculateWeightsAndGains((total_left_grad + default_grad) * identity,
|
||||
(total_left_hess + default_hess) * identity,
|
||||
l1, l2, &contrib_for_left, &gain_for_left);
|
||||
CalculateWeightsAndGains(
|
||||
(total_grad - default_grad - total_left_grad) * identity,
|
||||
(total_hess - default_hess - total_left_hess) * identity, l1, l2,
|
||||
&contrib_for_right, &gain_for_right);
|
||||
if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
|
||||
best_gain = gain_for_left + gain_for_right;
|
||||
best_bucket = bucket_id;
|
||||
best_f_dim = feature_dim;
|
||||
best_split_type = INEQUALITY_DEFAULT_LEFT;
|
||||
best_contrib_for_left = contrib_for_left[0];
|
||||
best_contrib_for_right = contrib_for_right[0];
|
||||
}
|
||||
} // for bucket_id
|
||||
} // for feature_dim
|
||||
if (best_gain != std::numeric_limits<float>::lowest()) {
|
||||
output_node_ids->push_back(node_id);
|
||||
// Remove the parent gain.
|
||||
output_gains->push_back(best_gain - parent_gain);
|
||||
output_feature_dimensions->push_back(best_f_dim);
|
||||
output_split_types->push_back(best_split_type);
|
||||
output_thresholds->push_back(best_bucket);
|
||||
output_left_node_contribs->push_back(best_contrib_for_left);
|
||||
output_right_node_contribs->push_back(best_contrib_for_right);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int logits_dim_;
|
||||
string split_type_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("BoostedTreesSparseCalculateBestFeatureSplit").Device(DEVICE_CPU),
|
||||
BoostedTreesSparseCalculateBestFeatureSplitOp);
|
||||
|
||||
class BoostedTreesMakeStatsSummaryOp : public OpKernel {
|
||||
public:
|
||||
explicit BoostedTreesMakeStatsSummaryOp(OpKernelConstruction* const context)
|
||||
|
@ -132,6 +132,48 @@ REGISTER_OP("BoostedTreesCalculateBestFeatureSplit")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit")
|
||||
.Input("node_id_range: int32")
|
||||
.Input("stats_summary_indices: int32")
|
||||
.Input("stats_summary_values: float")
|
||||
.Input("stats_summary_shape: int32")
|
||||
.Input("l1: float")
|
||||
.Input("l2: float")
|
||||
.Input("tree_complexity: float")
|
||||
.Input("min_node_weight: float")
|
||||
.Attr("logits_dimension: int >= 1")
|
||||
.Attr("split_type: {'inequality'} = 'inequality'")
|
||||
.Output("node_ids: int32")
|
||||
.Output("gains: float32")
|
||||
.Output("feature_dimensions: int32")
|
||||
.Output("thresholds: int32")
|
||||
.Output("left_node_contribs: float32")
|
||||
.Output("right_node_contribs: float32")
|
||||
.Output("split_with_default_directions: string")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle node_id_range_shape;
|
||||
shape_inference::ShapeHandle unused_shape;
|
||||
// node id range is rank 1 with 2 values.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &node_id_range_shape));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(node_id_range_shape, c->MakeShape({2}), &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused_shape));
|
||||
shape_inference::ShapeHandle summary_shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(summary_shape, c->MakeShape({4}), &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape));
|
||||
ShapeHandle output_shape = c->MakeShape({-1});
|
||||
for (int i = 0; i < 7; ++i) {
|
||||
c->set_output(i, output_shape);
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("BoostedTreesCreateEnsemble")
|
||||
.Input("tree_ensemble_handle: resource")
|
||||
.Input("stamp_token: int64")
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
_INEQUALITY_DEFAULT_LEFT = 'inequality_default_left'.encode('utf-8')
|
||||
_INEQUALITY_DEFAULT_RIGHT = 'inequality_default_right'.encode('utf-8')
|
||||
|
||||
|
||||
class StatsOpsTest(test_util.TensorFlowTestCase):
|
||||
@ -56,6 +57,98 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
|
||||
], # feature 1
|
||||
] # shape=[num_features, max_splits, num_buckets, 2]
|
||||
|
||||
def _get_sparse_stats_summary_for_split(self, stats_summary=None):
|
||||
if stats_summary is None:
|
||||
stats_summary = np.asarray(self._get_stats_summary_for_split())
|
||||
stats_summary[0][0][1] = np.zeros([2])
|
||||
stats_summary[1][0][2] = np.zeros([2])
|
||||
stats_summary = np.moveaxis(stats_summary, 0, 1)
|
||||
slices = stats_summary.nonzero()
|
||||
values = stats_summary[slices]
|
||||
indices = np.asarray(slices)
|
||||
return np.moveaxis(indices, 0, 1), values, stats_summary.shape
|
||||
|
||||
def testCalculateBestSplitsWithoutRegularizationInSparse(self):
|
||||
# This test uses the same data as dense, but run in sparse kernel and
|
||||
# make sure the sparse kernel returns same result as dense kernel.
|
||||
dense_summary = np.asarray([
|
||||
[
|
||||
[[0., 0.], [.0, .0], [0., 0.], [0., 0.]], # node 0; ignored
|
||||
[[0., 0.], [.15, .36], [.06, .07], [.1, .2]], # node 1
|
||||
[[0., 0.], [-.33, .58], [0., 0.], [.3, .4]], # node 2
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
|
||||
], # feature 0
|
||||
[
|
||||
[[0., 0.], [0., 0.], [.0, .0], [0., 0.]], # node 0; ignored
|
||||
[[0., 0.], [.3, .5], [-.05, .06], [.06, .07]], # node 1
|
||||
[[.1, .1], [.2, .3], [-.4, .5], [.07, .08]], # node 2
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
|
||||
], # feature 1
|
||||
]) # num_features * shape=[max_splits, num_buckets, 2]
|
||||
node_id_range = [1, 3]
|
||||
dense_summary = np.moveaxis(dense_summary, 0, 1)
|
||||
dense_shape = dense_summary.shape
|
||||
|
||||
default_bucket_summary = np.zeros(dense_shape[0:2] + (1, dense_shape[3]))
|
||||
sparse_summary = np.concatenate((dense_summary, default_bucket_summary),
|
||||
axis=2)
|
||||
slices = sparse_summary.nonzero()
|
||||
summary_values = sparse_summary[slices]
|
||||
summary_indices = np.asarray(slices)
|
||||
summary_indices = np.moveaxis(summary_indices, 0, 1)
|
||||
summary_shape = sparse_summary.shape
|
||||
|
||||
(node_ids, gains, _, _, left_node_contribs, right_node_contribs,
|
||||
_) = self.evaluate(
|
||||
boosted_trees_ops.sparse_calculate_best_feature_split(
|
||||
node_id_range,
|
||||
summary_indices,
|
||||
summary_values,
|
||||
summary_shape,
|
||||
l1=0.0,
|
||||
l2=0.0,
|
||||
tree_complexity=0.0,
|
||||
min_node_weight=0,
|
||||
logits_dimension=1))
|
||||
|
||||
self.assertAllEqual([1, 2], node_ids)
|
||||
self.assertAllClose([0.02823, 0.41184], gains)
|
||||
self.assertAllClose([-0.6], left_node_contribs[0])
|
||||
self.assertAllClose([-0.076923], right_node_contribs[0])
|
||||
|
||||
def testSparseCalculateBestSplitsWithoutRegularization(self):
|
||||
node_id_range = [1, 3]
|
||||
(summary_indices, summary_values,
|
||||
summary_shape) = self._get_sparse_stats_summary_for_split()
|
||||
|
||||
(node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
|
||||
right_node_contribs, split_types) = self.evaluate(
|
||||
boosted_trees_ops.sparse_calculate_best_feature_split(
|
||||
node_id_range,
|
||||
summary_indices,
|
||||
summary_values,
|
||||
summary_shape,
|
||||
l1=0.0,
|
||||
l2=0.0,
|
||||
tree_complexity=0.0,
|
||||
min_node_weight=0,
|
||||
logits_dimension=1))
|
||||
self.assertAllEqual([1, 2], node_ids)
|
||||
self.assertAllClose([0.116495, 0.60429], gains)
|
||||
self.assertAllEqual([1, 1], thresholds)
|
||||
self.assertAllEqual([1, 1], feature_dimensions)
|
||||
# The left node contrib will be later added to the previous node value to
|
||||
# make the left node value, and the same for right node contrib.
|
||||
self.assertAllClose([[-0.631579], [-0.770833]], left_node_contribs)
|
||||
self.assertAllClose([[0.833333], [0.8]], right_node_contribs)
|
||||
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
|
||||
|
||||
def testCalculateBestGainsWithoutRegularization(self):
|
||||
"""Testing Gain calculation without any regularization."""
|
||||
with self.cached_session() as sess:
|
||||
@ -174,6 +267,34 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllClose([[-.043478], [-.6]], right_node_contribs)
|
||||
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
|
||||
|
||||
def testSparseCalculateBestSplitsWithL2(self):
|
||||
node_id_range = [1, 3]
|
||||
(summary_indices, summary_values,
|
||||
summary_shape) = self._get_sparse_stats_summary_for_split()
|
||||
|
||||
(node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
|
||||
right_node_contribs, split_types) = self.evaluate(
|
||||
boosted_trees_ops.sparse_calculate_best_feature_split(
|
||||
node_id_range,
|
||||
summary_indices,
|
||||
summary_values,
|
||||
summary_shape,
|
||||
l1=0.0,
|
||||
l2=0.1,
|
||||
tree_complexity=0.0,
|
||||
min_node_weight=0,
|
||||
logits_dimension=1))
|
||||
self.assertAllEqual([1, 2], node_ids)
|
||||
self.assertAllClose([0.077414, 0.501868], gains)
|
||||
self.assertAllEqual([1, 1], feature_dimensions)
|
||||
self.assertAllEqual([1, 1], thresholds)
|
||||
# The left node contrib will be later added to the previous node value to
|
||||
# make the left node value, and the same for right node contrib.
|
||||
self.assertAllClose([[-0.537313], [-0.637931]], left_node_contribs)
|
||||
self.assertAllClose([[0.3125], [0.666667]], right_node_contribs)
|
||||
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT, _INEQUALITY_DEFAULT_LEFT],
|
||||
split_types)
|
||||
|
||||
def testCalculateBestGainsWithL1(self):
|
||||
"""Testing Gain calculation with L1."""
|
||||
with self.cached_session() as sess:
|
||||
@ -236,6 +357,33 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual([1, 1], feature_dimensions)
|
||||
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
|
||||
|
||||
def testSparseCalculateBestSplitsWithL1(self):
|
||||
node_id_range = [1, 3]
|
||||
(summary_indices, summary_values,
|
||||
summary_shape) = self._get_sparse_stats_summary_for_split()
|
||||
|
||||
(node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
|
||||
right_node_contribs, split_types) = self.evaluate(
|
||||
boosted_trees_ops.sparse_calculate_best_feature_split(
|
||||
node_id_range,
|
||||
summary_indices,
|
||||
summary_values,
|
||||
summary_shape,
|
||||
l1=0.1,
|
||||
l2=0.,
|
||||
tree_complexity=0.0,
|
||||
min_node_weight=0,
|
||||
logits_dimension=1))
|
||||
self.assertAllEqual([1, 2], node_ids)
|
||||
self.assertAllClose([0.048597, 0.331875], gains)
|
||||
self.assertAllEqual([1, 1], feature_dimensions)
|
||||
self.assertAllEqual([1, 1], thresholds)
|
||||
# The left node contrib will be later added to the previous node value to
|
||||
# make the left node value, and the same for right node contrib.
|
||||
self.assertAllClose([[-0.45614], [-0.5625]], left_node_contribs)
|
||||
self.assertAllClose([[0.0], [0.6]], right_node_contribs)
|
||||
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
|
||||
|
||||
def testCalculateBestGainsWithTreeComplexity(self):
|
||||
"""Testing best gain calculation with tree complexity."""
|
||||
with self.cached_session() as sess:
|
||||
@ -300,6 +448,33 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual([1, 0], feature_dimensions)
|
||||
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
|
||||
|
||||
def testSparseCalculateBestSplitsWithTreeComplexity(self):
|
||||
"""Testing best split calculation with tree complexity."""
|
||||
node_id_range = [1, 3]
|
||||
(summary_indices, summary_values,
|
||||
summary_shape) = self._get_sparse_stats_summary_for_split()
|
||||
|
||||
(node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
|
||||
right_node_contribs, split_types) = self.evaluate(
|
||||
boosted_trees_ops.sparse_calculate_best_feature_split(
|
||||
node_id_range,
|
||||
summary_indices,
|
||||
summary_values,
|
||||
summary_shape,
|
||||
l1=0.,
|
||||
l2=0.1,
|
||||
tree_complexity=3.,
|
||||
min_node_weight=0,
|
||||
logits_dimension=1))
|
||||
|
||||
self.assertAllEqual([1, 2], node_ids)
|
||||
self.assertAllClose([-2.922586, -2.498132], gains)
|
||||
self.assertAllEqual([1, 1], feature_dimensions)
|
||||
self.assertAllEqual([1, 1], thresholds)
|
||||
self.assertAllClose([[-0.537313], [-0.637931]], left_node_contribs)
|
||||
self.assertAllClose([[0.3125], [0.666667]], right_node_contribs)
|
||||
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
|
||||
|
||||
def testCalculateBestGainsWithMinNodeWeight(self):
|
||||
"""Testing Gain calculation with min node weight."""
|
||||
with self.cached_session() as sess:
|
||||
@ -393,8 +568,59 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual([1, 1], feature_dimensions)
|
||||
self.assertAllEqual([_INEQUALITY_DEFAULT_LEFT] * 2, split_types)
|
||||
|
||||
def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeaturePossible(self):
|
||||
"""Testing Gain calculation with min node weight and no split."""
|
||||
def testSparseCalculateBestSplitsWithMinNodeWeight(self):
|
||||
"""Testing best split calculation with min node weight."""
|
||||
node_id_range = [1, 3] # node 1 through 2 will be processed.
|
||||
stats_summary = np.asarray([
|
||||
[
|
||||
[[0., 0.], [.0, .0], [0., 0.], [0., 0.]], # node 0; ignored
|
||||
[[0., 0.], [.15, .36], [.06, .61], [.1, .2]], # node 1
|
||||
[[0., 0.], [-.33, .68], [0., 0.], [.3, .4]], # node 2
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
|
||||
], # feature 0
|
||||
[
|
||||
[[0., 0.], [0., 0.], [.0, .0], [0., 0.]], # node 0; ignored
|
||||
[[0., 0.], [-.05, .6], [.3, .5], [.06, .07]], # node 1
|
||||
[[.1, 1.], [.2, -.05], [-.4, .05], [.07, .08]], # node 2
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
|
||||
], # feature 1
|
||||
]) # num_features * shape=[max_splits, num_buckets, 2]
|
||||
# reshape to [max_splits, num_features, num_buckets, 2]
|
||||
stats_summary = np.moveaxis(stats_summary, 0, 1)
|
||||
|
||||
(summary_indices, summary_values,
|
||||
summary_shape) = self._get_sparse_stats_summary_for_split(stats_summary)
|
||||
|
||||
(node_ids, gains, feature_dimensions, thresholds, left_node_contribs,
|
||||
right_node_contribs, split_types) = self.evaluate(
|
||||
boosted_trees_ops.sparse_calculate_best_feature_split(
|
||||
node_id_range,
|
||||
summary_indices,
|
||||
summary_values,
|
||||
summary_shape,
|
||||
l1=0.,
|
||||
l2=0.,
|
||||
tree_complexity=0.,
|
||||
min_node_weight=1,
|
||||
logits_dimension=1))
|
||||
|
||||
self.assertAllEqual([1, 2], node_ids)
|
||||
self.assertAllClose([0.149398, 3.332079], gains)
|
||||
self.assertAllEqual([1, 1], thresholds)
|
||||
self.assertAllClose([[0.083333], [-0.359223]], left_node_contribs)
|
||||
self.assertAllClose([[-0.631579], [7.999998]], right_node_contribs)
|
||||
self.assertAllEqual([1, 1], feature_dimensions)
|
||||
self.assertAllEqual([_INEQUALITY_DEFAULT_RIGHT, _INEQUALITY_DEFAULT_LEFT],
|
||||
split_types)
|
||||
|
||||
def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeturePossible(self):
|
||||
"""Testing Gain calculation without any regularization."""
|
||||
with self.cached_session() as sess:
|
||||
max_splits = 7
|
||||
node_id_range = [1, 3] # node 1 through 2 will be processed.
|
||||
@ -497,6 +723,63 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
|
||||
logits_dimension=1)
|
||||
self.assertAllEqual([], node_ids)
|
||||
|
||||
def testSparseCalculateBestSplitsWithMinNodeWeightNoSplitOnFeature(self):
|
||||
"""Testing best split calculation with min node weight and no split."""
|
||||
node_id_range = [1, 3] # node 1 through 2 will be processed.
|
||||
stats_summary = np.asarray([
|
||||
[
|
||||
[[0., 0.], [.0, .0], [0., 0.], [0., 0.]], # node 0; ignored
|
||||
[[0., 0.], [.15, .36], [.06, .7], [.1, .2]], # node 1
|
||||
[[0., 0.], [-.33, .068], [0., 0.], [.3, .04]], # node 2
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
|
||||
], # feature 0
|
||||
[
|
||||
[[0., 0.], [0., 0.], [.0, .0], [0., 0.]], # node 0; ignored
|
||||
[[0., 0.], [.3, .5], [-.05, .6], [.06, .07]], # node 1
|
||||
[[.1, .1], [.2, .03], [-.4, .05], [.07, .08]], # node 2
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
|
||||
[[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
|
||||
], # feature 1
|
||||
]) # num_features * shape=[max_splits, num_buckets, 2]
|
||||
# reshape to [max_splits, num_features, num_buckets, 2]
|
||||
stats_summary = np.moveaxis(stats_summary, 0, 1)
|
||||
(summary_indices, summary_values,
|
||||
summary_shape) = self._get_sparse_stats_summary_for_split(stats_summary)
|
||||
|
||||
(node_ids, _, _, _, _, _, _) = self.evaluate(
|
||||
boosted_trees_ops.sparse_calculate_best_feature_split(
|
||||
node_id_range,
|
||||
summary_indices,
|
||||
summary_values,
|
||||
summary_shape,
|
||||
l1=0.,
|
||||
l2=0.,
|
||||
tree_complexity=0.,
|
||||
min_node_weight=1,
|
||||
logits_dimension=1))
|
||||
|
||||
# We can't split either of the nodes on the first feature
|
||||
self.assertAllEqual([1], node_ids)
|
||||
|
||||
# Now check when we can't split on any feature
|
||||
(node_ids, _, _, _, _, _, _) = self.evaluate(
|
||||
boosted_trees_ops.sparse_calculate_best_feature_split(
|
||||
node_id_range,
|
||||
summary_indices,
|
||||
summary_values,
|
||||
summary_shape,
|
||||
l1=0.,
|
||||
l2=0.,
|
||||
tree_complexity=0.,
|
||||
min_node_weight=10,
|
||||
logits_dimension=1))
|
||||
self.assertAllEqual([], node_ids)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMakeStatsSummarySimple(self):
|
||||
"""Simple test for MakeStatsSummary."""
|
||||
|
@ -40,6 +40,7 @@ from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_s
|
||||
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
|
||||
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as quantile_resource_handle_op
|
||||
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_aggregate_stats
|
||||
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_sparse_calculate_best_feature_split as sparse_calculate_best_feature_split
|
||||
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
|
||||
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
|
||||
from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as is_quantile_resource_initialized
|
||||
|
@ -532,6 +532,10 @@ tf_module {
|
||||
name: "BoostedTreesSparseAggregateStats"
|
||||
argspec: "args=[\'node_ids\', \'gradients\', \'hessians\', \'feature_indices\', \'feature_values\', \'feature_shape\', \'max_splits\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BoostedTreesSparseCalculateBestFeatureSplit"
|
||||
argspec: "args=[\'node_id_range\', \'stats_summary_indices\', \'stats_summary_values\', \'stats_summary_shape\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'split_type\', \'name\'], varargs=None, keywords=None, defaults=[\'inequality\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BoostedTreesTrainingPredict"
|
||||
argspec: "args=[\'tree_ensemble_handle\', \'cached_tree_ids\', \'cached_node_ids\', \'bucketized_features\', \'logits_dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -532,6 +532,10 @@ tf_module {
|
||||
name: "BoostedTreesSparseAggregateStats"
|
||||
argspec: "args=[\'node_ids\', \'gradients\', \'hessians\', \'feature_indices\', \'feature_values\', \'feature_shape\', \'max_splits\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BoostedTreesSparseCalculateBestFeatureSplit"
|
||||
argspec: "args=[\'node_id_range\', \'stats_summary_indices\', \'stats_summary_values\', \'stats_summary_shape\', \'l1\', \'l2\', \'tree_complexity\', \'min_node_weight\', \'logits_dimension\', \'split_type\', \'name\'], varargs=None, keywords=None, defaults=[\'inequality\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BoostedTreesTrainingPredict"
|
||||
argspec: "args=[\'tree_ensemble_handle\', \'cached_tree_ids\', \'cached_node_ids\', \'bucketized_features\', \'logits_dimension\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user