BoostedTreesUpdateEnsembleV2 kernel logic works on a list of feature_ids.

PiperOrigin-RevId: 290134014
Change-Id: I45ec10013c9271c432b94c8c8a84fef65a65e373
This commit is contained in:
A. Unique TensorFlower 2020-01-16 13:35:13 -08:00 committed by TensorFlower Gardener
parent ef8379e48b
commit fa657fb523
5 changed files with 478 additions and 493 deletions

View File

@@ -285,7 +285,7 @@ void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension, auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
left_node_id, right_node_id); left_node_id, right_node_id);
auto* new_split = node->mutable_bucketized_split(); auto* new_split = node->mutable_bucketized_split();
new_split->set_feature_id(candidate.feature_idx); new_split->set_feature_id(candidate.feature_id);
new_split->set_threshold(candidate.threshold); new_split->set_threshold(candidate.threshold);
new_split->set_dimension_id(candidate.dimension_id); new_split->set_dimension_id(candidate.dimension_id);
new_split->set_left_id(*left_node_id); new_split->set_left_id(*left_node_id);
@@ -310,7 +310,7 @@ void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension, auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
left_node_id, right_node_id); left_node_id, right_node_id);
auto* new_split = node->mutable_categorical_split(); auto* new_split = node->mutable_categorical_split();
new_split->set_feature_id(candidate.feature_idx); new_split->set_feature_id(candidate.feature_id);
new_split->set_value(candidate.threshold); new_split->set_value(candidate.threshold);
new_split->set_dimension_id(candidate.dimension_id); new_split->set_dimension_id(candidate.dimension_id);
new_split->set_left_id(*left_node_id); new_split->set_left_id(*left_node_id);

View File

@@ -189,10 +189,9 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
// Get current split candidate. // Get current split candidate.
const auto& node_id = node_ids(candidate_idx); const auto& node_id = node_ids(candidate_idx);
const auto& gain = gains(candidate_idx); const auto& gain = gains(candidate_idx);
const auto& best_split_it = best_split_per_node->find(node_id);
auto best_split_it = best_split_per_node->find(node_id);
boosted_trees::SplitCandidate candidate; boosted_trees::SplitCandidate candidate;
candidate.feature_idx = feature_ids(feature_idx); candidate.feature_id = feature_ids(feature_idx);
candidate.candidate_idx = candidate_idx; candidate.candidate_idx = candidate_idx;
candidate.gain = gain; candidate.gain = gain;
candidate.dimension_id = 0; candidate.dimension_id = 0;
@@ -207,8 +206,8 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() && if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
GainsAreEqual(gain, best_split_it->second.gain))) { GainsAreEqual(gain, best_split_it->second.gain))) {
const auto best_candidate = (*best_split_per_node)[node_id]; const auto best_candidate = (*best_split_per_node)[node_id];
const int32 best_feature_id = best_candidate.feature_idx; const int32 best_feature_id = best_candidate.feature_id;
const int32 feature_id = candidate.feature_idx; const int32 feature_id = candidate.feature_id;
VLOG(2) << "Breaking ties on feature ids and buckets"; VLOG(2) << "Breaking ties on feature ids and buckets";
// Breaking ties deterministically. // Breaking ties deterministically.
if (feature_id < best_feature_id) { if (feature_id < best_feature_id) {
@@ -235,8 +234,8 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
public: public:
explicit BoostedTreesUpdateEnsembleV2Op(OpKernelConstruction* const context) explicit BoostedTreesUpdateEnsembleV2Op(OpKernelConstruction* const context)
: OpKernel(context) { : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_)); OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
OP_REQUIRES_OK(context, context->GetAttr("num_groups", &num_groups_));
} }
void Compute(OpKernelContext* const context) override { void Compute(OpKernelContext* const context) override {
@@ -272,8 +271,6 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
OpInputList feature_ids_list; OpInputList feature_ids_list;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->input_list("feature_ids", &feature_ids_list)); context->input_list("feature_ids", &feature_ids_list));
// TODO(crawles): Read groups of feature ids and find best splits among all.
const auto feature_ids = feature_ids_list[0].vec<int32>();
const Tensor* max_depth_t; const Tensor* max_depth_t;
OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t)); OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
@@ -292,7 +289,7 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list, FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list,
thresholds_list, dimension_ids_list, thresholds_list, dimension_ids_list,
left_node_contribs_list, right_node_contribs_list, left_node_contribs_list, right_node_contribs_list,
split_types_list, feature_ids, &best_splits); split_types_list, feature_ids_list, &best_splits);
int32 current_tree = int32 current_tree =
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource); UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
@@ -395,38 +392,36 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
const OpInputList& thresholds_list, const OpInputList& dimension_ids_list, const OpInputList& thresholds_list, const OpInputList& dimension_ids_list,
const OpInputList& left_node_contribs_list, const OpInputList& left_node_contribs_list,
const OpInputList& right_node_contribs_list, const OpInputList& right_node_contribs_list,
const OpInputList& split_types_list, const OpInputList& split_types_list, const OpInputList& feature_ids_list,
const TTypes<const int32>::Vec& feature_ids,
std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) { std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) {
// Find best split per node going through every feature candidate. // Find best split per node going through every feature candidate.
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) { for (int64 group_idx = 0; group_idx < num_groups_; ++group_idx) {
const auto& node_ids = node_ids_list[feature_idx].vec<int32>(); const auto& node_ids = node_ids_list[group_idx].vec<int32>();
const auto& gains = gains_list[feature_idx].vec<float>(); const auto& gains = gains_list[group_idx].vec<float>();
const auto& thresholds = thresholds_list[feature_idx].vec<int32>(); const auto& feature_ids = feature_ids_list[group_idx].vec<int32>();
const auto& dimension_ids = dimension_ids_list[feature_idx].vec<int32>(); const auto& thresholds = thresholds_list[group_idx].vec<int32>();
const auto& dimension_ids = dimension_ids_list[group_idx].vec<int32>();
const auto& left_node_contribs = const auto& left_node_contribs =
left_node_contribs_list[feature_idx].matrix<float>(); left_node_contribs_list[group_idx].matrix<float>();
const auto& right_node_contribs = const auto& right_node_contribs =
right_node_contribs_list[feature_idx].matrix<float>(); right_node_contribs_list[group_idx].matrix<float>();
const auto& split_types = split_types_list[feature_idx].vec<tstring>(); const auto& split_types = split_types_list[group_idx].vec<tstring>();
for (size_t candidate_idx = 0; candidate_idx < node_ids.size(); for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
++candidate_idx) { ++candidate_idx) {
// Get current split candidate. // Get current split candidate.
const auto& node_id = node_ids(candidate_idx); const auto& node_id = node_ids(candidate_idx);
const auto& gain = gains(candidate_idx); const auto& gain = gains(candidate_idx);
const auto& threshold = thresholds(candidate_idx); const auto& feature_id = feature_ids(candidate_idx);
const auto& dimension_id = dimension_ids(candidate_idx);
const auto& split_type = split_types(candidate_idx);
auto best_split_it = best_split_per_node->find(node_id); auto best_split_it = best_split_per_node->find(node_id);
boosted_trees::SplitCandidate candidate; boosted_trees::SplitCandidate candidate;
candidate.feature_idx = feature_ids(feature_idx);
candidate.candidate_idx = candidate_idx; candidate.candidate_idx = candidate_idx;
candidate.gain = gain; candidate.gain = gain;
candidate.threshold = threshold; candidate.feature_id = feature_id;
candidate.dimension_id = dimension_id; candidate.threshold = thresholds(candidate_idx);
candidate.split_type = split_type; candidate.dimension_id = dimension_ids(candidate_idx);
candidate.split_type = split_types(candidate_idx);
for (int i = 0; i < logits_dim_; ++i) { for (int i = 0; i < logits_dim_; ++i) {
candidate.left_node_contribs.push_back( candidate.left_node_contribs.push_back(
learning_rate * left_node_contribs(candidate_idx, i)); learning_rate * left_node_contribs(candidate_idx, i));
@@ -435,9 +430,9 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
} }
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() && if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
GainsAreEqual(gain, best_split_it->second.gain))) { GainsAreEqual(gain, best_split_it->second.gain))) {
const auto best_candidate = (*best_split_per_node)[node_id]; const auto& best_candidate = (*best_split_per_node)[node_id];
const int32 best_feature_id = best_candidate.feature_idx; const int32 best_feature_id = best_candidate.feature_id;
const int32 feature_id = candidate.feature_idx; const int32 feature_id = candidate.feature_id;
VLOG(2) << "Breaking ties on feature ids and buckets"; VLOG(2) << "Breaking ties on feature ids and buckets";
// Breaking ties deterministically. // Breaking ties deterministically.
if (feature_id < best_feature_id) { if (feature_id < best_feature_id) {
@@ -452,8 +447,8 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
} }
private: private:
int32 num_features_;
int32 logits_dim_; int32 logits_dim_;
int32 num_groups_;
}; };
REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsembleV2").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsembleV2").Device(DEVICE_CPU),

View File

@@ -30,12 +30,10 @@ namespace boosted_trees {
struct SplitCandidate { struct SplitCandidate {
SplitCandidate() {} SplitCandidate() {}
// Index in the list of the feature ids.
int64 feature_idx = 0;
// Index in the tensor of node_ids for the feature with idx feature_idx. // Index in the tensor of node_ids for the feature with idx feature_idx.
int64 candidate_idx = 0; int64 candidate_idx = 0;
int64 feature_id = 0;
float gain = 0.0; float gain = 0.0;
int32 threshold = 0.0; int32 threshold = 0.0;
int32 dimension_id = 0; int32 dimension_id = 0;
@@ -56,20 +54,20 @@ static bool GainIsLarger(const float g1, const float g2) {
return g1 - g2 >= kTolerance; return g1 - g2 >= kTolerance;
} }
static void MultiDimLogitSolveForWeightAndGain(Eigen::MatrixXf hessian_and_reg, static void MultiDimLogitSolveForWeightAndGain(
Eigen::VectorXf g, const Eigen::MatrixXf& hessian_and_reg, const Eigen::VectorXf& g,
Eigen::VectorXf* weight, Eigen::VectorXf* weight, float* gain) {
float* gain) {
*weight = -hessian_and_reg.colPivHouseholderQr().solve(g); *weight = -hessian_and_reg.colPivHouseholderQr().solve(g);
*gain = -g.transpose() * (*weight); *gain = -g.transpose() * (*weight);
} }
static void CalculateWeightsAndGains(const Eigen::VectorXf g, // Used in stats_ops.cc to determine weights/gains for each feature split.
const Eigen::VectorXf h, const float l1, static void CalculateWeightsAndGains(const Eigen::VectorXf& g,
const Eigen::VectorXf& h, const float l1,
const float l2, Eigen::VectorXf* weight, const float l2, Eigen::VectorXf* weight,
float* gain) { float* gain) {
const float kEps = 1e-15; const float kEps = 1e-15;
int32 logits_dim = g.size(); const int32 logits_dim = g.size();
if (logits_dim == 1) { if (logits_dim == 1) {
// The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is // The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is
// (g+l1*sgn(w))^2/(h+l2). // (g+l1*sgn(w))^2/(h+l2).

View File

@@ -631,60 +631,61 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2")
.Input("pruning_mode: int32") .Input("pruning_mode: int32")
.Attr("num_features: int >= 0") // Inferred. .Attr("num_features: int >= 0") // Inferred.
.Attr("logits_dimension: int = 1") .Attr("logits_dimension: int = 1")
.Attr("num_groups: int = 1") // Number of groups to process. .Attr("num_groups: int = 1") // Inferred; number of groups to process.
.SetShapeFn([](shape_inference::InferenceContext* c) { .SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle shape_handle;
int num_features; int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
int num_groups;
TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
// Feature_ids, should be one for each feature.
shape_inference::ShapeHandle feature_ids_shape;
// TODO(crawles): remove 1 hardcode once kernel operates on multiple
// groups.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
TF_RETURN_IF_ERROR(
c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
int logits_dimension; int logits_dimension;
int num_groups;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension)); TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
for (int i = 0; i < num_features; ++i) { TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
// num_features was kept for backwards compatibility reasons. It now
// represents number of groups.
DCHECK_EQ(num_features, num_groups);
shape_inference::ShapeHandle shape_handle;
for (int i = 0; i < num_groups; ++i) {
int offset = i + 1;
// Feature ids
TF_RETURN_IF_ERROR(c->WithRank(c->input(offset), 1, &shape_handle));
// Dimension ids. // Dimension ids.
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle)); TF_RETURN_IF_ERROR(
c->WithRank(c->input(offset + num_features), 1, &shape_handle));
// Node ids. // Node ids.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features + 2), 1, &shape_handle)); c->WithRank(c->input(offset + num_features * 2), 1, &shape_handle));
auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)}); auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
auto shape_rank_2 = auto shape_rank_2 =
c->MakeShape({c->Dim(shape_handle, 0), logits_dimension}); c->MakeShape({c->Dim(shape_handle, 0), logits_dimension});
// Gains. // Gains.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle)); c->WithRank(c->input(offset + num_features * 3), 1, &shape_handle));
// TODO(nponomareva): replace this with input("name",vector of shapes). // TODO(nponomareva): replace this with input("name",vector of shapes).
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2), TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 3),
shape_rank_1, &shape_handle)); shape_rank_1, &shape_handle));
// Thresholds. // Thresholds.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 3 + 2), 1, &shape_handle)); c->WithRank(c->input(offset + num_features * 4), 1, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2), TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 4),
shape_rank_1, &shape_handle)); shape_rank_1, &shape_handle));
// Left and right node contribs. // Left and right node contribs.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle)); c->WithRank(c->input(offset + num_features * 5), 2, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2), TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 5),
shape_rank_2, &shape_handle)); shape_rank_2, &shape_handle));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 5 + 2), 2, &shape_handle)); c->WithRank(c->input(offset + num_features * 6), 2, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 5 + 2), TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 6),
shape_rank_2, &shape_handle)); shape_rank_2, &shape_handle));
// Split types. // Split types.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 6 + 2), 1, &shape_handle)); c->WithRank(c->input(offset + num_features * 7), 1, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 6 + 2), TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 7),
shape_rank_1, &shape_handle)); shape_rank_1, &shape_handle));
} }
return Status::OK(); return Status::OK();