BoostedTreesUpdateEnsembleV2 kernel logic now works on a list of feature_ids.

PiperOrigin-RevId: 290134014
Change-Id: I45ec10013c9271c432b94c8c8a84fef65a65e373
This commit is contained in:
A. Unique TensorFlower 2020-01-16 13:35:13 -08:00 committed by TensorFlower Gardener
parent ef8379e48b
commit fa657fb523
5 changed files with 478 additions and 493 deletions

View File

@ -285,7 +285,7 @@ void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
left_node_id, right_node_id);
auto* new_split = node->mutable_bucketized_split();
new_split->set_feature_id(candidate.feature_idx);
new_split->set_feature_id(candidate.feature_id);
new_split->set_threshold(candidate.threshold);
new_split->set_dimension_id(candidate.dimension_id);
new_split->set_left_id(*left_node_id);
@ -310,7 +310,7 @@ void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
left_node_id, right_node_id);
auto* new_split = node->mutable_categorical_split();
new_split->set_feature_id(candidate.feature_idx);
new_split->set_feature_id(candidate.feature_id);
new_split->set_value(candidate.threshold);
new_split->set_dimension_id(candidate.dimension_id);
new_split->set_left_id(*left_node_id);

View File

@ -189,10 +189,9 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
// Get current split candidate.
const auto& node_id = node_ids(candidate_idx);
const auto& gain = gains(candidate_idx);
auto best_split_it = best_split_per_node->find(node_id);
const auto& best_split_it = best_split_per_node->find(node_id);
boosted_trees::SplitCandidate candidate;
candidate.feature_idx = feature_ids(feature_idx);
candidate.feature_id = feature_ids(feature_idx);
candidate.candidate_idx = candidate_idx;
candidate.gain = gain;
candidate.dimension_id = 0;
@ -207,8 +206,8 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
GainsAreEqual(gain, best_split_it->second.gain))) {
const auto best_candidate = (*best_split_per_node)[node_id];
const int32 best_feature_id = best_candidate.feature_idx;
const int32 feature_id = candidate.feature_idx;
const int32 best_feature_id = best_candidate.feature_id;
const int32 feature_id = candidate.feature_id;
VLOG(2) << "Breaking ties on feature ids and buckets";
// Breaking ties deterministically.
if (feature_id < best_feature_id) {
@ -235,8 +234,8 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
public:
explicit BoostedTreesUpdateEnsembleV2Op(OpKernelConstruction* const context)
: OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
OP_REQUIRES_OK(context, context->GetAttr("num_groups", &num_groups_));
}
void Compute(OpKernelContext* const context) override {
@ -272,8 +271,6 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
OpInputList feature_ids_list;
OP_REQUIRES_OK(context,
context->input_list("feature_ids", &feature_ids_list));
// TODO(crawles): Read groups of feature ids and find best splits among all.
const auto feature_ids = feature_ids_list[0].vec<int32>();
const Tensor* max_depth_t;
OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
@ -292,7 +289,7 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list,
thresholds_list, dimension_ids_list,
left_node_contribs_list, right_node_contribs_list,
split_types_list, feature_ids, &best_splits);
split_types_list, feature_ids_list, &best_splits);
int32 current_tree =
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
@ -395,38 +392,36 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
const OpInputList& thresholds_list, const OpInputList& dimension_ids_list,
const OpInputList& left_node_contribs_list,
const OpInputList& right_node_contribs_list,
const OpInputList& split_types_list,
const TTypes<const int32>::Vec& feature_ids,
const OpInputList& split_types_list, const OpInputList& feature_ids_list,
std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) {
// Find best split per node going through every feature candidate.
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
const auto& gains = gains_list[feature_idx].vec<float>();
const auto& thresholds = thresholds_list[feature_idx].vec<int32>();
const auto& dimension_ids = dimension_ids_list[feature_idx].vec<int32>();
for (int64 group_idx = 0; group_idx < num_groups_; ++group_idx) {
const auto& node_ids = node_ids_list[group_idx].vec<int32>();
const auto& gains = gains_list[group_idx].vec<float>();
const auto& feature_ids = feature_ids_list[group_idx].vec<int32>();
const auto& thresholds = thresholds_list[group_idx].vec<int32>();
const auto& dimension_ids = dimension_ids_list[group_idx].vec<int32>();
const auto& left_node_contribs =
left_node_contribs_list[feature_idx].matrix<float>();
left_node_contribs_list[group_idx].matrix<float>();
const auto& right_node_contribs =
right_node_contribs_list[feature_idx].matrix<float>();
const auto& split_types = split_types_list[feature_idx].vec<tstring>();
right_node_contribs_list[group_idx].matrix<float>();
const auto& split_types = split_types_list[group_idx].vec<tstring>();
for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
++candidate_idx) {
// Get current split candidate.
const auto& node_id = node_ids(candidate_idx);
const auto& gain = gains(candidate_idx);
const auto& threshold = thresholds(candidate_idx);
const auto& dimension_id = dimension_ids(candidate_idx);
const auto& split_type = split_types(candidate_idx);
const auto& feature_id = feature_ids(candidate_idx);
auto best_split_it = best_split_per_node->find(node_id);
boosted_trees::SplitCandidate candidate;
candidate.feature_idx = feature_ids(feature_idx);
candidate.candidate_idx = candidate_idx;
candidate.gain = gain;
candidate.threshold = threshold;
candidate.dimension_id = dimension_id;
candidate.split_type = split_type;
candidate.feature_id = feature_id;
candidate.threshold = thresholds(candidate_idx);
candidate.dimension_id = dimension_ids(candidate_idx);
candidate.split_type = split_types(candidate_idx);
for (int i = 0; i < logits_dim_; ++i) {
candidate.left_node_contribs.push_back(
learning_rate * left_node_contribs(candidate_idx, i));
@ -435,9 +430,9 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
}
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
GainsAreEqual(gain, best_split_it->second.gain))) {
const auto best_candidate = (*best_split_per_node)[node_id];
const int32 best_feature_id = best_candidate.feature_idx;
const int32 feature_id = candidate.feature_idx;
const auto& best_candidate = (*best_split_per_node)[node_id];
const int32 best_feature_id = best_candidate.feature_id;
const int32 feature_id = candidate.feature_id;
VLOG(2) << "Breaking ties on feature ids and buckets";
// Breaking ties deterministically.
if (feature_id < best_feature_id) {
@ -452,8 +447,8 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
}
private:
int32 num_features_;
int32 logits_dim_;
int32 num_groups_;
};
REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsembleV2").Device(DEVICE_CPU),

View File

@ -30,12 +30,10 @@ namespace boosted_trees {
struct SplitCandidate {
SplitCandidate() {}
// Index in the list of the feature ids.
int64 feature_idx = 0;
// Index in the tensor of node_ids for the feature with idx feature_idx.
int64 candidate_idx = 0;
int64 feature_id = 0;
float gain = 0.0;
int32 threshold = 0.0;
int32 dimension_id = 0;
@ -56,20 +54,20 @@ static bool GainIsLarger(const float g1, const float g2) {
return g1 - g2 >= kTolerance;
}
static void MultiDimLogitSolveForWeightAndGain(Eigen::MatrixXf hessian_and_reg,
Eigen::VectorXf g,
Eigen::VectorXf* weight,
float* gain) {
// Computes the weight vector and the gain for a multi-dimensional logit.
//
// Solves the linear system `hessian_and_reg * x = g` using Eigen's
// colPivHouseholderQr() decomposition and stores the negated solution in
// `*weight`. `*gain` is then set to `-g^T * (*weight)`.
//
// Args:
//   hessian_and_reg: coefficient matrix of the system (taken by const
//     reference; not modified). Presumably the Hessian with the
//     regularization term already added -- verify against the caller.
//   g: right-hand side of the system; also used to compute the gain.
//   weight: output; overwritten with the negated solution.
//   gain: output; overwritten with the scalar gain.
//
// Both output pointers are dereferenced unconditionally, so callers must pass
// valid, non-null pointers.
// NOTE(review): assumes hessian_and_reg has g.size() rows -- confirm at the
// call site.
static void MultiDimLogitSolveForWeightAndGain(
const Eigen::MatrixXf& hessian_and_reg, const Eigen::VectorXf& g,
Eigen::VectorXf* weight, float* gain) {
// weight = -(hessian_and_reg)^-1 * g, obtained through a QR solve rather than
// an explicit matrix inverse.
*weight = -hessian_and_reg.colPivHouseholderQr().solve(g);
// The product of the row vector -g^T with *weight is assigned to the float.
*gain = -g.transpose() * (*weight);
}
static void CalculateWeightsAndGains(const Eigen::VectorXf g,
const Eigen::VectorXf h, const float l1,
// Used in stats_ops.cc to determine weights/gains for each feature split.
static void CalculateWeightsAndGains(const Eigen::VectorXf& g,
const Eigen::VectorXf& h, const float l1,
const float l2, Eigen::VectorXf* weight,
float* gain) {
const float kEps = 1e-15;
int32 logits_dim = g.size();
const int32 logits_dim = g.size();
if (logits_dim == 1) {
// The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is
// (g+l1*sgn(w))^2/(h+l2).

View File

@ -631,60 +631,61 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2")
.Input("pruning_mode: int32")
.Attr("num_features: int >= 0") // Inferred.
.Attr("logits_dimension: int = 1")
.Attr("num_groups: int = 1") // Number of groups to process.
.Attr("num_groups: int = 1") // Inferred; number of groups to process.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle shape_handle;
int num_features;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
int num_groups;
TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
// Feature_ids, should be one for each feature.
shape_inference::ShapeHandle feature_ids_shape;
// TODO(crawles): remove 1 hardcode once kernel operates on multiple
// groups.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
TF_RETURN_IF_ERROR(
c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
int logits_dimension;
int num_groups;
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
for (int i = 0; i < num_features; ++i) {
TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
// num_features was kept for backwards compatibility reasons. It now
// represents number of groups.
DCHECK_EQ(num_features, num_groups);
shape_inference::ShapeHandle shape_handle;
for (int i = 0; i < num_groups; ++i) {
int offset = i + 1;
// Feature ids
TF_RETURN_IF_ERROR(c->WithRank(c->input(offset), 1, &shape_handle));
// Dimension ids.
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle));
TF_RETURN_IF_ERROR(
c->WithRank(c->input(offset + num_features), 1, &shape_handle));
// Node ids.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features + 2), 1, &shape_handle));
c->WithRank(c->input(offset + num_features * 2), 1, &shape_handle));
auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
auto shape_rank_2 =
c->MakeShape({c->Dim(shape_handle, 0), logits_dimension});
// Gains.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle));
c->WithRank(c->input(offset + num_features * 3), 1, &shape_handle));
// TODO(nponomareva): replace this with input("name",vector of shapes).
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2),
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 3),
shape_rank_1, &shape_handle));
// Thresholds.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 3 + 2), 1, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2),
c->WithRank(c->input(offset + num_features * 4), 1, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 4),
shape_rank_1, &shape_handle));
// Left and right node contribs.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2),
c->WithRank(c->input(offset + num_features * 5), 2, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 5),
shape_rank_2, &shape_handle));
TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 5 + 2), 2, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 5 + 2),
c->WithRank(c->input(offset + num_features * 6), 2, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 6),
shape_rank_2, &shape_handle));
// Split types.
TF_RETURN_IF_ERROR(
c->WithRank(c->input(i + num_features * 6 + 2), 1, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 6 + 2),
c->WithRank(c->input(offset + num_features * 7), 1, &shape_handle));
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 7),
shape_rank_1, &shape_handle));
}
return Status::OK();