Merge pull request from bmc2-stripe:bmc/202003--leaf_ids

PiperOrigin-RevId: 341728992
Change-Id: Iac0a4c5fde134b968a913fdbae9046c47e3bfc06
This commit is contained in:
TensorFlower Gardener 2020-11-10 17:16:10 -08:00
commit 7da164e6c3
3 changed files with 26 additions and 4 deletions
tensorflow
core/kernels/boosted_trees
python/kernel_tests/boosted_trees

View File

@ -164,7 +164,7 @@ message TreeEnsemble {
// DebugOutput contains outputs useful for debugging/model interpretation, at // DebugOutput contains outputs useful for debugging/model interpretation, at
// the individual example-level. Debug outputs that are available to the user // the individual example-level. Debug outputs that are available to the user
// are: 1) Directional feature contributions (DFCs) 2) Node IDs for ensemble // are: 1) Directional feature contributions (DFCs) 2) Node IDs for ensemble
// prediction path 3) Leaf node IDs. // prediction paths 3) Leaf node IDs.
message DebugOutput { message DebugOutput {
// Return the logits and associated feature splits across prediction paths for // Return the logits and associated feature splits across prediction paths for
// each tree, for every example, at predict time. We will use these values to // each tree, for every example, at predict time. We will use these values to
@ -173,7 +173,7 @@ message DebugOutput {
// id. // id.
repeated int32 feature_ids = 1; repeated int32 feature_ids = 1;
repeated float logits_path = 2; repeated float logits_path = 2;
// Return the node_id for each leaf node we reach in our prediction path.
// TODO(crawles): return 2) Node IDs for ensemble prediction path 3) Leaf node repeated int32 leaf_node_ids = 3;
// IDs. // TODO(crawles): return 4) Node IDs for ensemble prediction path
} }

View File

@ -361,6 +361,7 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
if (tree_id == 0 || node_id > 0) { if (tree_id == 0 || node_id > 0) {
past_trees_logit += tree_logit; past_trees_logit += tree_logit;
} }
example_debug_info.add_leaf_node_ids(node_id);
++tree_id; ++tree_id;
node_id = 0; node_id = 0;
} else { // Add to proto. } else { // Add to proto.

View File

@ -2673,6 +2673,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
# Expected logits are computed by traversing the logit path and # Expected logits are computed by traversing the logit path and
# subtracting child logits from parent logits. # subtracting child logits from parent logits.
bias = 1.72 * 0.1 # Root node of tree_0. bias = 1.72 * 0.1 # Root node of tree_0.
expected_leaf_ids = ((0,), (0,))
expected_feature_ids = ((), ()) expected_feature_ids = ((), ())
expected_logits_paths = ((bias,), (bias,)) expected_logits_paths = ((bias,), (bias,))
@ -2688,14 +2689,17 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
serialized_examples_debug_outputs = session.run(debug_op) serialized_examples_debug_outputs = session.run(debug_op)
feature_ids = [] feature_ids = []
logits_paths = [] logits_paths = []
leaf_ids = []
for example in serialized_examples_debug_outputs: for example in serialized_examples_debug_outputs:
example_debug_outputs = boosted_trees_pb2.DebugOutput() example_debug_outputs = boosted_trees_pb2.DebugOutput()
example_debug_outputs.ParseFromString(example) example_debug_outputs.ParseFromString(example)
feature_ids.append(example_debug_outputs.feature_ids) feature_ids.append(example_debug_outputs.feature_ids)
logits_paths.append(example_debug_outputs.logits_path) logits_paths.append(example_debug_outputs.logits_path)
leaf_ids.append(example_debug_outputs.leaf_node_ids)
self.assertAllClose(feature_ids, expected_feature_ids) self.assertAllClose(feature_ids, expected_feature_ids)
self.assertAllClose(logits_paths, expected_logits_paths) self.assertAllClose(logits_paths, expected_logits_paths)
self.assertAllClose(expected_leaf_ids, leaf_ids)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testContribsForOnlyABiasNodeMultiDimensionFeature(self): def testContribsForOnlyABiasNodeMultiDimensionFeature(self):
@ -2734,6 +2738,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
# Expected logits are computed by traversing the logit path and # Expected logits are computed by traversing the logit path and
# subtracting child logits from parent logits. # subtracting child logits from parent logits.
bias = 1.72 * 0.1 # Root node of tree_0. bias = 1.72 * 0.1 # Root node of tree_0.
expected_leaf_ids = ((0,), (0,))
expected_feature_ids = ((), ()) expected_feature_ids = ((), ())
expected_logits_paths = ((bias,), (bias,)) expected_logits_paths = ((bias,), (bias,))
@ -2749,14 +2754,17 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
serialized_examples_debug_outputs = session.run(debug_op) serialized_examples_debug_outputs = session.run(debug_op)
feature_ids = [] feature_ids = []
logits_paths = [] logits_paths = []
leaf_ids = []
for example in serialized_examples_debug_outputs: for example in serialized_examples_debug_outputs:
example_debug_outputs = boosted_trees_pb2.DebugOutput() example_debug_outputs = boosted_trees_pb2.DebugOutput()
example_debug_outputs.ParseFromString(example) example_debug_outputs.ParseFromString(example)
feature_ids.append(example_debug_outputs.feature_ids) feature_ids.append(example_debug_outputs.feature_ids)
logits_paths.append(example_debug_outputs.logits_path) logits_paths.append(example_debug_outputs.logits_path)
leaf_ids.append(example_debug_outputs.leaf_node_ids)
self.assertAllClose(feature_ids, expected_feature_ids) self.assertAllClose(feature_ids, expected_feature_ids)
self.assertAllClose(logits_paths, expected_logits_paths) self.assertAllClose(logits_paths, expected_logits_paths)
self.assertAllClose(leaf_ids, expected_leaf_ids)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self): def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
@ -2834,6 +2842,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
# example_0 : (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias) # example_0 : (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias)
# example_1 : (bias, 0.1 * 7. + bias ) # example_1 : (bias, 0.1 * 7. + bias )
expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42)) expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42))
expected_leaf_ids = ((0, 3), (0, 2))
bucketized_features = [ bucketized_features = [
feature_0_values, feature_1_values, feature_2_values feature_0_values, feature_1_values, feature_2_values
@ -2847,14 +2856,18 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
serialized_examples_debug_outputs = session.run(debug_op) serialized_examples_debug_outputs = session.run(debug_op)
feature_ids = [] feature_ids = []
logits_paths = [] logits_paths = []
leaf_ids = []
for example in serialized_examples_debug_outputs: for example in serialized_examples_debug_outputs:
example_debug_outputs = boosted_trees_pb2.DebugOutput() example_debug_outputs = boosted_trees_pb2.DebugOutput()
example_debug_outputs.ParseFromString(example) example_debug_outputs.ParseFromString(example)
feature_ids.append(example_debug_outputs.feature_ids) feature_ids.append(example_debug_outputs.feature_ids)
logits_paths.append(example_debug_outputs.logits_path) logits_paths.append(example_debug_outputs.logits_path)
leaf_ids.append(example_debug_outputs.leaf_node_ids)
self.assertAllClose(feature_ids, expected_feature_ids) self.assertAllClose(feature_ids, expected_feature_ids)
self.assertAllClose(logits_paths, expected_logits_paths) self.assertAllClose(logits_paths, expected_logits_paths)
self.assertAllClose(leaf_ids, expected_leaf_ids)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testContribsMultipleTreeWhenFirstTreeIsABiasNodeMultiDimFeature(self): def testContribsMultipleTreeWhenFirstTreeIsABiasNodeMultiDimFeature(self):
@ -2933,6 +2946,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
# example_0 : (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias) # example_0 : (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias)
# example_1 : (bias, 0.1 * 7. + bias ) # example_1 : (bias, 0.1 * 7. + bias )
expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42)) expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42))
expected_leaf_ids = ((0, 3), (0, 2))
bucketized_features = [ bucketized_features = [
feature_0_values, feature_1_values, feature_2_values feature_0_values, feature_1_values, feature_2_values
@ -2946,14 +2960,17 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
serialized_examples_debug_outputs = session.run(debug_op) serialized_examples_debug_outputs = session.run(debug_op)
feature_ids = [] feature_ids = []
logits_paths = [] logits_paths = []
leaf_ids = []
for example in serialized_examples_debug_outputs: for example in serialized_examples_debug_outputs:
example_debug_outputs = boosted_trees_pb2.DebugOutput() example_debug_outputs = boosted_trees_pb2.DebugOutput()
example_debug_outputs.ParseFromString(example) example_debug_outputs.ParseFromString(example)
feature_ids.append(example_debug_outputs.feature_ids) feature_ids.append(example_debug_outputs.feature_ids)
logits_paths.append(example_debug_outputs.logits_path) logits_paths.append(example_debug_outputs.logits_path)
leaf_ids.append(example_debug_outputs.leaf_node_ids)
self.assertAllClose(feature_ids, expected_feature_ids) self.assertAllClose(feature_ids, expected_feature_ids)
self.assertAllClose(logits_paths, expected_logits_paths) self.assertAllClose(logits_paths, expected_logits_paths)
self.assertAllClose(leaf_ids, expected_leaf_ids)
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testContribsMultipleTree(self): def testContribsMultipleTree(self):
@ -3075,6 +3092,7 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
# 1.0 * -7. + 0.2 * 7 + .114) # 1.0 * -7. + 0.2 * 7 + .114)
expected_logits_paths = ((bias, 0.114, 1.214, 1.114, 6.114), expected_logits_paths = ((bias, 0.114, 1.214, 1.114, 6.114),
(bias, 0.114, 1.514, -5.486)) (bias, 0.114, 1.514, -5.486))
expected_leaf_ids = ((1, 3, 2), (1, 2, 1))
bucketized_features = [ bucketized_features = [
feature_0_values, feature_1_values, feature_2_values feature_0_values, feature_1_values, feature_2_values
@ -3088,14 +3106,17 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
serialized_examples_debug_outputs = session.run(debug_op) serialized_examples_debug_outputs = session.run(debug_op)
feature_ids = [] feature_ids = []
logits_paths = [] logits_paths = []
leaf_ids = []
for example in serialized_examples_debug_outputs: for example in serialized_examples_debug_outputs:
example_debug_outputs = boosted_trees_pb2.DebugOutput() example_debug_outputs = boosted_trees_pb2.DebugOutput()
example_debug_outputs.ParseFromString(example) example_debug_outputs.ParseFromString(example)
feature_ids.append(example_debug_outputs.feature_ids) feature_ids.append(example_debug_outputs.feature_ids)
logits_paths.append(example_debug_outputs.logits_path) logits_paths.append(example_debug_outputs.logits_path)
leaf_ids.append(example_debug_outputs.leaf_node_ids)
self.assertAllClose(feature_ids, expected_feature_ids) self.assertAllClose(feature_ids, expected_feature_ids)
self.assertAllClose(logits_paths, expected_logits_paths) self.assertAllClose(logits_paths, expected_logits_paths)
self.assertAllClose(expected_leaf_ids, leaf_ids)
if __name__ == '__main__': if __name__ == '__main__':