Make the BoostedTreesUpdateEnsembleV2 kernel operate on a list of feature_ids tensors (one per group, holding one feature id per split candidate).
PiperOrigin-RevId: 290134014 Change-Id: I45ec10013c9271c432b94c8c8a84fef65a65e373
This commit is contained in:
parent
ef8379e48b
commit
fa657fb523
@ -285,7 +285,7 @@ void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
|
|||||||
auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
|
auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
|
||||||
left_node_id, right_node_id);
|
left_node_id, right_node_id);
|
||||||
auto* new_split = node->mutable_bucketized_split();
|
auto* new_split = node->mutable_bucketized_split();
|
||||||
new_split->set_feature_id(candidate.feature_idx);
|
new_split->set_feature_id(candidate.feature_id);
|
||||||
new_split->set_threshold(candidate.threshold);
|
new_split->set_threshold(candidate.threshold);
|
||||||
new_split->set_dimension_id(candidate.dimension_id);
|
new_split->set_dimension_id(candidate.dimension_id);
|
||||||
new_split->set_left_id(*left_node_id);
|
new_split->set_left_id(*left_node_id);
|
||||||
@ -310,7 +310,7 @@ void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
|
|||||||
auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
|
auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
|
||||||
left_node_id, right_node_id);
|
left_node_id, right_node_id);
|
||||||
auto* new_split = node->mutable_categorical_split();
|
auto* new_split = node->mutable_categorical_split();
|
||||||
new_split->set_feature_id(candidate.feature_idx);
|
new_split->set_feature_id(candidate.feature_id);
|
||||||
new_split->set_value(candidate.threshold);
|
new_split->set_value(candidate.threshold);
|
||||||
new_split->set_dimension_id(candidate.dimension_id);
|
new_split->set_dimension_id(candidate.dimension_id);
|
||||||
new_split->set_left_id(*left_node_id);
|
new_split->set_left_id(*left_node_id);
|
||||||
|
@ -189,10 +189,9 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
|||||||
// Get current split candidate.
|
// Get current split candidate.
|
||||||
const auto& node_id = node_ids(candidate_idx);
|
const auto& node_id = node_ids(candidate_idx);
|
||||||
const auto& gain = gains(candidate_idx);
|
const auto& gain = gains(candidate_idx);
|
||||||
|
const auto& best_split_it = best_split_per_node->find(node_id);
|
||||||
auto best_split_it = best_split_per_node->find(node_id);
|
|
||||||
boosted_trees::SplitCandidate candidate;
|
boosted_trees::SplitCandidate candidate;
|
||||||
candidate.feature_idx = feature_ids(feature_idx);
|
candidate.feature_id = feature_ids(feature_idx);
|
||||||
candidate.candidate_idx = candidate_idx;
|
candidate.candidate_idx = candidate_idx;
|
||||||
candidate.gain = gain;
|
candidate.gain = gain;
|
||||||
candidate.dimension_id = 0;
|
candidate.dimension_id = 0;
|
||||||
@ -207,8 +206,8 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
|||||||
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
|
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
|
||||||
GainsAreEqual(gain, best_split_it->second.gain))) {
|
GainsAreEqual(gain, best_split_it->second.gain))) {
|
||||||
const auto best_candidate = (*best_split_per_node)[node_id];
|
const auto best_candidate = (*best_split_per_node)[node_id];
|
||||||
const int32 best_feature_id = best_candidate.feature_idx;
|
const int32 best_feature_id = best_candidate.feature_id;
|
||||||
const int32 feature_id = candidate.feature_idx;
|
const int32 feature_id = candidate.feature_id;
|
||||||
VLOG(2) << "Breaking ties on feature ids and buckets";
|
VLOG(2) << "Breaking ties on feature ids and buckets";
|
||||||
// Breaking ties deterministically.
|
// Breaking ties deterministically.
|
||||||
if (feature_id < best_feature_id) {
|
if (feature_id < best_feature_id) {
|
||||||
@ -235,8 +234,8 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
public:
|
public:
|
||||||
explicit BoostedTreesUpdateEnsembleV2Op(OpKernelConstruction* const context)
|
explicit BoostedTreesUpdateEnsembleV2Op(OpKernelConstruction* const context)
|
||||||
: OpKernel(context) {
|
: OpKernel(context) {
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
|
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
|
OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("num_groups", &num_groups_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* const context) override {
|
void Compute(OpKernelContext* const context) override {
|
||||||
@ -272,8 +271,6 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
OpInputList feature_ids_list;
|
OpInputList feature_ids_list;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->input_list("feature_ids", &feature_ids_list));
|
context->input_list("feature_ids", &feature_ids_list));
|
||||||
// TODO(crawles): Read groups of feature ids and find best splits among all.
|
|
||||||
const auto feature_ids = feature_ids_list[0].vec<int32>();
|
|
||||||
|
|
||||||
const Tensor* max_depth_t;
|
const Tensor* max_depth_t;
|
||||||
OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
|
OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
|
||||||
@ -292,7 +289,7 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list,
|
FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list,
|
||||||
thresholds_list, dimension_ids_list,
|
thresholds_list, dimension_ids_list,
|
||||||
left_node_contribs_list, right_node_contribs_list,
|
left_node_contribs_list, right_node_contribs_list,
|
||||||
split_types_list, feature_ids, &best_splits);
|
split_types_list, feature_ids_list, &best_splits);
|
||||||
|
|
||||||
int32 current_tree =
|
int32 current_tree =
|
||||||
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
|
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
|
||||||
@ -395,38 +392,36 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
const OpInputList& thresholds_list, const OpInputList& dimension_ids_list,
|
const OpInputList& thresholds_list, const OpInputList& dimension_ids_list,
|
||||||
const OpInputList& left_node_contribs_list,
|
const OpInputList& left_node_contribs_list,
|
||||||
const OpInputList& right_node_contribs_list,
|
const OpInputList& right_node_contribs_list,
|
||||||
const OpInputList& split_types_list,
|
const OpInputList& split_types_list, const OpInputList& feature_ids_list,
|
||||||
const TTypes<const int32>::Vec& feature_ids,
|
|
||||||
std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) {
|
std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) {
|
||||||
// Find best split per node going through every feature candidate.
|
// Find best split per node going through every feature candidate.
|
||||||
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
|
for (int64 group_idx = 0; group_idx < num_groups_; ++group_idx) {
|
||||||
const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
|
const auto& node_ids = node_ids_list[group_idx].vec<int32>();
|
||||||
const auto& gains = gains_list[feature_idx].vec<float>();
|
const auto& gains = gains_list[group_idx].vec<float>();
|
||||||
const auto& thresholds = thresholds_list[feature_idx].vec<int32>();
|
const auto& feature_ids = feature_ids_list[group_idx].vec<int32>();
|
||||||
const auto& dimension_ids = dimension_ids_list[feature_idx].vec<int32>();
|
const auto& thresholds = thresholds_list[group_idx].vec<int32>();
|
||||||
|
const auto& dimension_ids = dimension_ids_list[group_idx].vec<int32>();
|
||||||
const auto& left_node_contribs =
|
const auto& left_node_contribs =
|
||||||
left_node_contribs_list[feature_idx].matrix<float>();
|
left_node_contribs_list[group_idx].matrix<float>();
|
||||||
const auto& right_node_contribs =
|
const auto& right_node_contribs =
|
||||||
right_node_contribs_list[feature_idx].matrix<float>();
|
right_node_contribs_list[group_idx].matrix<float>();
|
||||||
const auto& split_types = split_types_list[feature_idx].vec<tstring>();
|
const auto& split_types = split_types_list[group_idx].vec<tstring>();
|
||||||
|
|
||||||
for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
|
for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
|
||||||
++candidate_idx) {
|
++candidate_idx) {
|
||||||
// Get current split candidate.
|
// Get current split candidate.
|
||||||
const auto& node_id = node_ids(candidate_idx);
|
const auto& node_id = node_ids(candidate_idx);
|
||||||
const auto& gain = gains(candidate_idx);
|
const auto& gain = gains(candidate_idx);
|
||||||
const auto& threshold = thresholds(candidate_idx);
|
const auto& feature_id = feature_ids(candidate_idx);
|
||||||
const auto& dimension_id = dimension_ids(candidate_idx);
|
|
||||||
const auto& split_type = split_types(candidate_idx);
|
|
||||||
|
|
||||||
auto best_split_it = best_split_per_node->find(node_id);
|
auto best_split_it = best_split_per_node->find(node_id);
|
||||||
boosted_trees::SplitCandidate candidate;
|
boosted_trees::SplitCandidate candidate;
|
||||||
candidate.feature_idx = feature_ids(feature_idx);
|
|
||||||
candidate.candidate_idx = candidate_idx;
|
candidate.candidate_idx = candidate_idx;
|
||||||
candidate.gain = gain;
|
candidate.gain = gain;
|
||||||
candidate.threshold = threshold;
|
candidate.feature_id = feature_id;
|
||||||
candidate.dimension_id = dimension_id;
|
candidate.threshold = thresholds(candidate_idx);
|
||||||
candidate.split_type = split_type;
|
candidate.dimension_id = dimension_ids(candidate_idx);
|
||||||
|
candidate.split_type = split_types(candidate_idx);
|
||||||
for (int i = 0; i < logits_dim_; ++i) {
|
for (int i = 0; i < logits_dim_; ++i) {
|
||||||
candidate.left_node_contribs.push_back(
|
candidate.left_node_contribs.push_back(
|
||||||
learning_rate * left_node_contribs(candidate_idx, i));
|
learning_rate * left_node_contribs(candidate_idx, i));
|
||||||
@ -435,9 +430,9 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
}
|
}
|
||||||
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
|
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
|
||||||
GainsAreEqual(gain, best_split_it->second.gain))) {
|
GainsAreEqual(gain, best_split_it->second.gain))) {
|
||||||
const auto best_candidate = (*best_split_per_node)[node_id];
|
const auto& best_candidate = (*best_split_per_node)[node_id];
|
||||||
const int32 best_feature_id = best_candidate.feature_idx;
|
const int32 best_feature_id = best_candidate.feature_id;
|
||||||
const int32 feature_id = candidate.feature_idx;
|
const int32 feature_id = candidate.feature_id;
|
||||||
VLOG(2) << "Breaking ties on feature ids and buckets";
|
VLOG(2) << "Breaking ties on feature ids and buckets";
|
||||||
// Breaking ties deterministically.
|
// Breaking ties deterministically.
|
||||||
if (feature_id < best_feature_id) {
|
if (feature_id < best_feature_id) {
|
||||||
@ -452,8 +447,8 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int32 num_features_;
|
|
||||||
int32 logits_dim_;
|
int32 logits_dim_;
|
||||||
|
int32 num_groups_;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsembleV2").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsembleV2").Device(DEVICE_CPU),
|
||||||
|
@ -30,12 +30,10 @@ namespace boosted_trees {
|
|||||||
struct SplitCandidate {
|
struct SplitCandidate {
|
||||||
SplitCandidate() {}
|
SplitCandidate() {}
|
||||||
|
|
||||||
// Index in the list of the feature ids.
|
|
||||||
int64 feature_idx = 0;
|
|
||||||
|
|
||||||
// Index in the tensor of node_ids for the feature with idx feature_idx.
|
// Index of this candidate within the node_ids tensor it was read from.
|
||||||
int64 candidate_idx = 0;
|
int64 candidate_idx = 0;
|
||||||
|
|
||||||
|
int64 feature_id = 0;
|
||||||
float gain = 0.0;
|
float gain = 0.0;
|
||||||
int32 threshold = 0.0;
|
int32 threshold = 0.0;
|
||||||
int32 dimension_id = 0;
|
int32 dimension_id = 0;
|
||||||
@ -56,20 +54,20 @@ static bool GainIsLarger(const float g1, const float g2) {
|
|||||||
return g1 - g2 >= kTolerance;
|
return g1 - g2 >= kTolerance;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void MultiDimLogitSolveForWeightAndGain(Eigen::MatrixXf hessian_and_reg,
|
static void MultiDimLogitSolveForWeightAndGain(
|
||||||
Eigen::VectorXf g,
|
const Eigen::MatrixXf& hessian_and_reg, const Eigen::VectorXf& g,
|
||||||
Eigen::VectorXf* weight,
|
Eigen::VectorXf* weight, float* gain) {
|
||||||
float* gain) {
|
|
||||||
*weight = -hessian_and_reg.colPivHouseholderQr().solve(g);
|
*weight = -hessian_and_reg.colPivHouseholderQr().solve(g);
|
||||||
*gain = -g.transpose() * (*weight);
|
*gain = -g.transpose() * (*weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void CalculateWeightsAndGains(const Eigen::VectorXf g,
|
// Used in stats_ops.cc to determine weights/gains for each feature split.
|
||||||
const Eigen::VectorXf h, const float l1,
|
static void CalculateWeightsAndGains(const Eigen::VectorXf& g,
|
||||||
|
const Eigen::VectorXf& h, const float l1,
|
||||||
const float l2, Eigen::VectorXf* weight,
|
const float l2, Eigen::VectorXf* weight,
|
||||||
float* gain) {
|
float* gain) {
|
||||||
const float kEps = 1e-15;
|
const float kEps = 1e-15;
|
||||||
int32 logits_dim = g.size();
|
const int32 logits_dim = g.size();
|
||||||
if (logits_dim == 1) {
|
if (logits_dim == 1) {
|
||||||
// The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is
|
// The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is
|
||||||
// (g+l1*sgn(w))^2/(h+l2).
|
// (g+l1*sgn(w))^2/(h+l2).
|
||||||
|
@ -631,60 +631,61 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2")
|
|||||||
.Input("pruning_mode: int32")
|
.Input("pruning_mode: int32")
|
||||||
.Attr("num_features: int >= 0") // Inferred.
|
.Attr("num_features: int >= 0") // Inferred.
|
||||||
.Attr("logits_dimension: int = 1")
|
.Attr("logits_dimension: int = 1")
|
||||||
.Attr("num_groups: int = 1") // Number of groups to process.
|
.Attr("num_groups: int = 1") // Inferred; number of groups to process.
|
||||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
shape_inference::ShapeHandle shape_handle;
|
|
||||||
int num_features;
|
int num_features;
|
||||||
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
|
|
||||||
int num_groups;
|
|
||||||
TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
|
|
||||||
|
|
||||||
// Feature_ids, should be one for each feature.
|
|
||||||
shape_inference::ShapeHandle feature_ids_shape;
|
|
||||||
// TODO(crawles): remove 1 hardcode once kernel operates on multiple
|
|
||||||
// groups.
|
|
||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
|
|
||||||
|
|
||||||
int logits_dimension;
|
int logits_dimension;
|
||||||
|
int num_groups;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
|
||||||
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
|
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
|
||||||
for (int i = 0; i < num_features; ++i) {
|
TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
|
||||||
|
// num_features was kept for backwards compatibility reasons. It now
|
||||||
|
// represents the number of groups.
|
||||||
|
DCHECK_EQ(num_features, num_groups);
|
||||||
|
shape_inference::ShapeHandle shape_handle;
|
||||||
|
for (int i = 0; i < num_groups; ++i) {
|
||||||
|
int offset = i + 1;
|
||||||
|
// Feature ids
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(offset), 1, &shape_handle));
|
||||||
|
|
||||||
// Dimension ids.
|
// Dimension ids.
|
||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle));
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->WithRank(c->input(offset + num_features), 1, &shape_handle));
|
||||||
|
|
||||||
// Node ids.
|
// Node ids.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
c->WithRank(c->input(i + num_features + 2), 1, &shape_handle));
|
c->WithRank(c->input(offset + num_features * 2), 1, &shape_handle));
|
||||||
auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
|
auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
|
||||||
auto shape_rank_2 =
|
auto shape_rank_2 =
|
||||||
c->MakeShape({c->Dim(shape_handle, 0), logits_dimension});
|
c->MakeShape({c->Dim(shape_handle, 0), logits_dimension});
|
||||||
|
|
||||||
// Gains.
|
// Gains.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle));
|
c->WithRank(c->input(offset + num_features * 3), 1, &shape_handle));
|
||||||
// TODO(nponomareva): replace this with input("name",vector of shapes).
|
// TODO(nponomareva): replace this with input("name",vector of shapes).
|
||||||
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2),
|
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 3),
|
||||||
shape_rank_1, &shape_handle));
|
shape_rank_1, &shape_handle));
|
||||||
|
|
||||||
// Thresholds.
|
// Thresholds.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
c->WithRank(c->input(i + num_features * 3 + 2), 1, &shape_handle));
|
c->WithRank(c->input(offset + num_features * 4), 1, &shape_handle));
|
||||||
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2),
|
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 4),
|
||||||
shape_rank_1, &shape_handle));
|
shape_rank_1, &shape_handle));
|
||||||
|
|
||||||
// Left and right node contribs.
|
// Left and right node contribs.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle));
|
c->WithRank(c->input(offset + num_features * 5), 2, &shape_handle));
|
||||||
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2),
|
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 5),
|
||||||
shape_rank_2, &shape_handle));
|
shape_rank_2, &shape_handle));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
c->WithRank(c->input(i + num_features * 5 + 2), 2, &shape_handle));
|
c->WithRank(c->input(offset + num_features * 6), 2, &shape_handle));
|
||||||
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 5 + 2),
|
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 6),
|
||||||
shape_rank_2, &shape_handle));
|
shape_rank_2, &shape_handle));
|
||||||
|
|
||||||
// Split types.
|
// Split types.
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
c->WithRank(c->input(i + num_features * 6 + 2), 1, &shape_handle));
|
c->WithRank(c->input(offset + num_features * 7), 1, &shape_handle));
|
||||||
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 6 + 2),
|
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 7),
|
||||||
shape_rank_1, &shape_handle));
|
shape_rank_1, &shape_handle));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user