Adding more checks for categorical splits.
PiperOrigin-RevId: 223366746
This commit is contained in:
parent
51d694e20b
commit
d2253ab518
@ -834,8 +834,13 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
|
||||
root_gradient_stats *= normalizer_ratio;
|
||||
NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
|
||||
int32 best_feature_idx = 0;
|
||||
bool best_feature_updated = false;
|
||||
NodeStats best_right_node_stats(0);
|
||||
NodeStats best_left_node_stats(0);
|
||||
CHECK(end_index - start_index >= 2)
|
||||
<< "Partition should have a non bias feature. Start index "
|
||||
<< start_index << " and end index " << end_index;
|
||||
|
||||
for (int64 feature_idx = start_index + 1; feature_idx < end_index;
|
||||
++feature_idx) {
|
||||
GradientStats left_gradient_stats(*gradients_t, *hessians_t,
|
||||
@ -845,11 +850,13 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
|
||||
root_gradient_stats - left_gradient_stats;
|
||||
NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
|
||||
NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
|
||||
if (left_stats.gain + right_stats.gain > best_gain) {
|
||||
if (!best_feature_updated ||
|
||||
left_stats.gain + right_stats.gain > best_gain) {
|
||||
best_gain = left_stats.gain + right_stats.gain;
|
||||
best_left_node_stats = left_stats;
|
||||
best_right_node_stats = right_stats;
|
||||
best_feature_idx = feature_idx;
|
||||
best_feature_updated = true;
|
||||
}
|
||||
}
|
||||
SplitInfo split_info;
|
||||
@ -864,7 +871,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
|
||||
<< feature_ids(best_feature_idx, 0) << ", "
|
||||
<< feature_ids(best_feature_idx, 1)
|
||||
<< "\nPartition IDS: " << partition_ids(start_index) << " "
|
||||
<< partition_ids(best_feature_idx);
|
||||
<< partition_ids(best_feature_idx) << " and best gain " << best_gain;
|
||||
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
|
||||
auto* left_child = split_info.mutable_left_child();
|
||||
auto* right_child = split_info.mutable_right_child();
|
||||
|
Loading…
x
Reference in New Issue
Block a user