Adding more checks for categorical splits.

PiperOrigin-RevId: 223366746
This commit is contained in:
A. Unique TensorFlower 2018-11-29 10:16:39 -08:00 committed by TensorFlower Gardener
parent 51d694e20b
commit d2253ab518

View File

@@ -834,8 +834,13 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
root_gradient_stats *= normalizer_ratio;
NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
int32 best_feature_idx = 0;
+ bool best_feature_updated = false;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
+ CHECK(end_index - start_index >= 2)
+ << "Partition should have a non bias feature. Start index "
+ << start_index << " and end index " << end_index;
for (int64 feature_idx = start_index + 1; feature_idx < end_index;
++feature_idx) {
GradientStats left_gradient_stats(*gradients_t, *hessians_t,
@@ -845,11 +850,13 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
root_gradient_stats - left_gradient_stats;
NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
- if (left_stats.gain + right_stats.gain > best_gain) {
+ if (!best_feature_updated ||
+ left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
best_right_node_stats = right_stats;
best_feature_idx = feature_idx;
+ best_feature_updated = true;
}
}
SplitInfo split_info;
@@ -864,7 +871,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
<< feature_ids(best_feature_idx, 0) << ", "
<< feature_ids(best_feature_idx, 1)
<< "\nPartition IDS: " << partition_ids(start_index) << " "
- << partition_ids(best_feature_idx);
+ << partition_ids(best_feature_idx) << " and best gain " << best_gain;
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();