BoostedTreesUpdateEnsembleV2 kernel logic works on a list of feature_ids.
PiperOrigin-RevId: 290134014
Change-Id: I45ec10013c9271c432b94c8c8a84fef65a65e373
This commit is contained in:
parent
ef8379e48b
commit
fa657fb523
@ -285,7 +285,7 @@ void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
|
||||
auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
|
||||
left_node_id, right_node_id);
|
||||
auto* new_split = node->mutable_bucketized_split();
|
||||
new_split->set_feature_id(candidate.feature_idx);
|
||||
new_split->set_feature_id(candidate.feature_id);
|
||||
new_split->set_threshold(candidate.threshold);
|
||||
new_split->set_dimension_id(candidate.dimension_id);
|
||||
new_split->set_left_id(*left_node_id);
|
||||
@ -310,7 +310,7 @@ void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
|
||||
auto* node = AddLeafNodes(tree_id, split_entry, logits_dimension,
|
||||
left_node_id, right_node_id);
|
||||
auto* new_split = node->mutable_categorical_split();
|
||||
new_split->set_feature_id(candidate.feature_idx);
|
||||
new_split->set_feature_id(candidate.feature_id);
|
||||
new_split->set_value(candidate.threshold);
|
||||
new_split->set_dimension_id(candidate.dimension_id);
|
||||
new_split->set_left_id(*left_node_id);
|
||||
|
@ -189,10 +189,9 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
||||
// Get current split candidate.
|
||||
const auto& node_id = node_ids(candidate_idx);
|
||||
const auto& gain = gains(candidate_idx);
|
||||
|
||||
auto best_split_it = best_split_per_node->find(node_id);
|
||||
const auto& best_split_it = best_split_per_node->find(node_id);
|
||||
boosted_trees::SplitCandidate candidate;
|
||||
candidate.feature_idx = feature_ids(feature_idx);
|
||||
candidate.feature_id = feature_ids(feature_idx);
|
||||
candidate.candidate_idx = candidate_idx;
|
||||
candidate.gain = gain;
|
||||
candidate.dimension_id = 0;
|
||||
@ -207,8 +206,8 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
||||
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
|
||||
GainsAreEqual(gain, best_split_it->second.gain))) {
|
||||
const auto best_candidate = (*best_split_per_node)[node_id];
|
||||
const int32 best_feature_id = best_candidate.feature_idx;
|
||||
const int32 feature_id = candidate.feature_idx;
|
||||
const int32 best_feature_id = best_candidate.feature_id;
|
||||
const int32 feature_id = candidate.feature_id;
|
||||
VLOG(2) << "Breaking ties on feature ids and buckets";
|
||||
// Breaking ties deterministically.
|
||||
if (feature_id < best_feature_id) {
|
||||
@ -235,8 +234,8 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
||||
public:
|
||||
explicit BoostedTreesUpdateEnsembleV2Op(OpKernelConstruction* const context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_features", &num_features_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("logits_dimension", &logits_dim_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_groups", &num_groups_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* const context) override {
|
||||
@ -272,8 +271,6 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
||||
OpInputList feature_ids_list;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->input_list("feature_ids", &feature_ids_list));
|
||||
// TODO(crawles): Read groups of feature ids and find best splits among all.
|
||||
const auto feature_ids = feature_ids_list[0].vec<int32>();
|
||||
|
||||
const Tensor* max_depth_t;
|
||||
OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
|
||||
@ -292,7 +289,7 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
||||
FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list,
|
||||
thresholds_list, dimension_ids_list,
|
||||
left_node_contribs_list, right_node_contribs_list,
|
||||
split_types_list, feature_ids, &best_splits);
|
||||
split_types_list, feature_ids_list, &best_splits);
|
||||
|
||||
int32 current_tree =
|
||||
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
|
||||
@ -395,38 +392,36 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
||||
const OpInputList& thresholds_list, const OpInputList& dimension_ids_list,
|
||||
const OpInputList& left_node_contribs_list,
|
||||
const OpInputList& right_node_contribs_list,
|
||||
const OpInputList& split_types_list,
|
||||
const TTypes<const int32>::Vec& feature_ids,
|
||||
const OpInputList& split_types_list, const OpInputList& feature_ids_list,
|
||||
std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) {
|
||||
// Find best split per node going through every feature candidate.
|
||||
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
|
||||
const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
|
||||
const auto& gains = gains_list[feature_idx].vec<float>();
|
||||
const auto& thresholds = thresholds_list[feature_idx].vec<int32>();
|
||||
const auto& dimension_ids = dimension_ids_list[feature_idx].vec<int32>();
|
||||
for (int64 group_idx = 0; group_idx < num_groups_; ++group_idx) {
|
||||
const auto& node_ids = node_ids_list[group_idx].vec<int32>();
|
||||
const auto& gains = gains_list[group_idx].vec<float>();
|
||||
const auto& feature_ids = feature_ids_list[group_idx].vec<int32>();
|
||||
const auto& thresholds = thresholds_list[group_idx].vec<int32>();
|
||||
const auto& dimension_ids = dimension_ids_list[group_idx].vec<int32>();
|
||||
const auto& left_node_contribs =
|
||||
left_node_contribs_list[feature_idx].matrix<float>();
|
||||
left_node_contribs_list[group_idx].matrix<float>();
|
||||
const auto& right_node_contribs =
|
||||
right_node_contribs_list[feature_idx].matrix<float>();
|
||||
const auto& split_types = split_types_list[feature_idx].vec<tstring>();
|
||||
right_node_contribs_list[group_idx].matrix<float>();
|
||||
const auto& split_types = split_types_list[group_idx].vec<tstring>();
|
||||
|
||||
for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
|
||||
++candidate_idx) {
|
||||
// Get current split candidate.
|
||||
const auto& node_id = node_ids(candidate_idx);
|
||||
const auto& gain = gains(candidate_idx);
|
||||
const auto& threshold = thresholds(candidate_idx);
|
||||
const auto& dimension_id = dimension_ids(candidate_idx);
|
||||
const auto& split_type = split_types(candidate_idx);
|
||||
const auto& feature_id = feature_ids(candidate_idx);
|
||||
|
||||
auto best_split_it = best_split_per_node->find(node_id);
|
||||
boosted_trees::SplitCandidate candidate;
|
||||
candidate.feature_idx = feature_ids(feature_idx);
|
||||
candidate.candidate_idx = candidate_idx;
|
||||
candidate.gain = gain;
|
||||
candidate.threshold = threshold;
|
||||
candidate.dimension_id = dimension_id;
|
||||
candidate.split_type = split_type;
|
||||
candidate.feature_id = feature_id;
|
||||
candidate.threshold = thresholds(candidate_idx);
|
||||
candidate.dimension_id = dimension_ids(candidate_idx);
|
||||
candidate.split_type = split_types(candidate_idx);
|
||||
for (int i = 0; i < logits_dim_; ++i) {
|
||||
candidate.left_node_contribs.push_back(
|
||||
learning_rate * left_node_contribs(candidate_idx, i));
|
||||
@ -435,9 +430,9 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
||||
}
|
||||
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
|
||||
GainsAreEqual(gain, best_split_it->second.gain))) {
|
||||
const auto best_candidate = (*best_split_per_node)[node_id];
|
||||
const int32 best_feature_id = best_candidate.feature_idx;
|
||||
const int32 feature_id = candidate.feature_idx;
|
||||
const auto& best_candidate = (*best_split_per_node)[node_id];
|
||||
const int32 best_feature_id = best_candidate.feature_id;
|
||||
const int32 feature_id = candidate.feature_id;
|
||||
VLOG(2) << "Breaking ties on feature ids and buckets";
|
||||
// Breaking ties deterministically.
|
||||
if (feature_id < best_feature_id) {
|
||||
@ -452,8 +447,8 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
int32 num_features_;
|
||||
int32 logits_dim_;
|
||||
int32 num_groups_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsembleV2").Device(DEVICE_CPU),
|
||||
|
@ -30,12 +30,10 @@ namespace boosted_trees {
|
||||
struct SplitCandidate {
|
||||
SplitCandidate() {}
|
||||
|
||||
// Index in the list of the feature ids.
|
||||
int64 feature_idx = 0;
|
||||
|
||||
// Index in the tensor of node_ids for the feature with idx feature_idx.
|
||||
int64 candidate_idx = 0;
|
||||
|
||||
int64 feature_id = 0;
|
||||
float gain = 0.0;
|
||||
int32 threshold = 0.0;
|
||||
int32 dimension_id = 0;
|
||||
@ -56,20 +54,20 @@ static bool GainIsLarger(const float g1, const float g2) {
|
||||
return g1 - g2 >= kTolerance;
|
||||
}
|
||||
|
||||
static void MultiDimLogitSolveForWeightAndGain(Eigen::MatrixXf hessian_and_reg,
|
||||
Eigen::VectorXf g,
|
||||
Eigen::VectorXf* weight,
|
||||
float* gain) {
|
||||
static void MultiDimLogitSolveForWeightAndGain(
|
||||
const Eigen::MatrixXf& hessian_and_reg, const Eigen::VectorXf& g,
|
||||
Eigen::VectorXf* weight, float* gain) {
|
||||
*weight = -hessian_and_reg.colPivHouseholderQr().solve(g);
|
||||
*gain = -g.transpose() * (*weight);
|
||||
}
|
||||
|
||||
static void CalculateWeightsAndGains(const Eigen::VectorXf g,
|
||||
const Eigen::VectorXf h, const float l1,
|
||||
// Used in stats_ops.cc to determine weights/gains for each feature split.
|
||||
static void CalculateWeightsAndGains(const Eigen::VectorXf& g,
|
||||
const Eigen::VectorXf& h, const float l1,
|
||||
const float l2, Eigen::VectorXf* weight,
|
||||
float* gain) {
|
||||
const float kEps = 1e-15;
|
||||
int32 logits_dim = g.size();
|
||||
const int32 logits_dim = g.size();
|
||||
if (logits_dim == 1) {
|
||||
// The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is
|
||||
// (g+l1*sgn(w))^2/(h+l2).
|
||||
|
@ -631,60 +631,61 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2")
|
||||
.Input("pruning_mode: int32")
|
||||
.Attr("num_features: int >= 0") // Inferred.
|
||||
.Attr("logits_dimension: int = 1")
|
||||
.Attr("num_groups: int = 1") // Number of groups to process.
|
||||
.Attr("num_groups: int = 1") // Inferred; number of groups to process.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle shape_handle;
|
||||
int num_features;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
|
||||
int num_groups;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
|
||||
|
||||
// Feature_ids, should be one for each feature.
|
||||
shape_inference::ShapeHandle feature_ids_shape;
|
||||
// TODO(crawles): remove 1 hardcode once kernel operates on multiple
|
||||
// groups.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
|
||||
|
||||
int logits_dimension;
|
||||
int num_groups;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
|
||||
for (int i = 0; i < num_features; ++i) {
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
|
||||
// num_features was kept for backwards compatibility reasons. It now
|
||||
// represents number of groups.
|
||||
DCHECK_EQ(num_features, num_groups);
|
||||
shape_inference::ShapeHandle shape_handle;
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int offset = i + 1;
|
||||
// Feature ids
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(offset), 1, &shape_handle));
|
||||
|
||||
// Dimension ids.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2), 1, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(offset + num_features), 1, &shape_handle));
|
||||
|
||||
// Node ids.
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(i + num_features + 2), 1, &shape_handle));
|
||||
c->WithRank(c->input(offset + num_features * 2), 1, &shape_handle));
|
||||
auto shape_rank_1 = c->MakeShape({c->Dim(shape_handle, 0)});
|
||||
auto shape_rank_2 =
|
||||
c->MakeShape({c->Dim(shape_handle, 0), logits_dimension});
|
||||
|
||||
// Gains.
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(i + num_features * 2 + 2), 1, &shape_handle));
|
||||
c->WithRank(c->input(offset + num_features * 3), 1, &shape_handle));
|
||||
// TODO(nponomareva): replace this with input("name",vector of shapes).
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 2 + 2),
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 3),
|
||||
shape_rank_1, &shape_handle));
|
||||
|
||||
// Thresholds.
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(i + num_features * 3 + 2), 1, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 3 + 2),
|
||||
c->WithRank(c->input(offset + num_features * 4), 1, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 4),
|
||||
shape_rank_1, &shape_handle));
|
||||
|
||||
// Left and right node contribs.
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(i + num_features * 4 + 2), 2, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 4 + 2),
|
||||
c->WithRank(c->input(offset + num_features * 5), 2, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 5),
|
||||
shape_rank_2, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(i + num_features * 5 + 2), 2, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 5 + 2),
|
||||
c->WithRank(c->input(offset + num_features * 6), 2, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 6),
|
||||
shape_rank_2, &shape_handle));
|
||||
|
||||
// Split types.
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(i + num_features * 6 + 2), 1, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(i + num_features * 6 + 2),
|
||||
c->WithRank(c->input(offset + num_features * 7), 1, &shape_handle));
|
||||
TF_RETURN_IF_ERROR(c->Merge(c->input(offset + num_features * 7),
|
||||
shape_rank_1, &shape_handle));
|
||||
}
|
||||
return Status::OK();
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user