From ce127f779dbc6f9d65e17cc3c38f37a06ba666d0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 22 Aug 2018 14:29:51 -0700 Subject: [PATCH] Fix the interaction between the split handler ops and the grow tree ensemble ops when we have nodes with no examples. PiperOrigin-RevId: 209830550 --- .../kernels/split_handler_ops.cc | 1 + .../boosted_trees/kernels/training_ops.cc | 43 +- .../batch/ordinal_split_handler_test.py | 71 +-- .../boosted_trees/proto/split_info.proto | 4 + .../python/kernel_tests/training_ops_test.py | 547 +++++++++++++++++- 5 files changed, 611 insertions(+), 55 deletions(-) diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 64349cfca39..3a486353193 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -400,6 +400,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel { const int start_index = partition_boundaries[root_idx]; (*output_partition_ids)(root_idx) = partition_ids(start_index); + oblivious_split_info.add_children_parent_id(partition_ids(start_index)); } oblivious_split_info.SerializeToString(&(*output_splits)(0)); } diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index bb5ae78d9bf..ab2853352a7 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= +#include + #include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h" #include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" #include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h" @@ -772,20 +774,32 @@ class GrowTreeEnsembleOp : public OpKernel { // The number of new children. int num_children = 1 << (depth + 1); auto split_info = split->oblivious_split_info; - CHECK(num_children == split_info.children_size()) - << "Wrong number of new children: " << num_children - << " != " << split_info.children_size(); - for (int idx = 0; idx < num_children; idx += 2) { - // Old leaf is at position depth + idx / 2. + CHECK(num_children >= split_info.children_size()) + << "Too many new children, expected <= " << num_children << " and got " + << split_info.children_size(); + std::vector new_leaves; + new_leaves.reserve(num_children); + int next_id = 0; + for (int idx = 0; idx < num_children / 2; idx++) { trees::Leaf old_leaf = - *tree_config->mutable_nodes(depth + idx / 2)->mutable_leaf(); - // Update left leaf. - *split_info.mutable_children(idx) = - *MergeLeafWeights(old_leaf, split_info.mutable_children(idx)); - // Update right leaf. - *split_info.mutable_children(idx + 1) = - *MergeLeafWeights(old_leaf, split_info.mutable_children(idx + 1)); + *tree_config->mutable_nodes(depth + idx)->mutable_leaf(); + // Check if a split was made for this leaf. + if (next_id < split_info.children_parent_id_size() && + depth + idx == split_info.children_parent_id(next_id)) { + // Add left leaf. + new_leaves.push_back(*MergeLeafWeights( + old_leaf, split_info.mutable_children(2 * next_id))); + // Add right leaf. + new_leaves.push_back(*MergeLeafWeights( + old_leaf, split_info.mutable_children(2 * next_id + 1))); + next_id++; + } else { + // If there is no split for this leaf, just duplicate it. + new_leaves.push_back(old_leaf); + new_leaves.push_back(old_leaf); + } } + CHECK(next_id == split_info.children_parent_id_size()); TreeNodeMetadata* split_metadata = split_info.mutable_split_node()->mutable_node_metadata(); split_metadata->set_gain(split->gain); @@ -804,11 +818,10 @@ class GrowTreeEnsembleOp : public OpKernel { if (idx + depth + 1 < nodes_size) { // Update leaves that were already there. *tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() = - *split_info.mutable_children(idx); + new_leaves[idx]; } else { // Add new leaves. - *tree_config->add_nodes()->mutable_leaf() = - *split_info.mutable_children(idx); + *tree_config->add_nodes()->mutable_leaf() = new_leaves[idx]; } } } diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index d9caebb645a..31043264a11 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -186,14 +186,14 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): with self.test_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | - # i0 | (0.2, 0.12) | 0 | 2 | - # i1 | (-0.5, 0.07) | 0 | 2 | - # i2 | (1.2, 0.2) | 0 | 0 | - # i3 | (4.0, 0.13) | 1 | 1 | + # i0 | (0.2, 0.12) | 1 | 2 | + # i1 | (-0.5, 0.07) | 1 | 2 | + # i2 | (1.2, 0.2) | 1 | 0 | + # i3 | (4.0, 0.13) | 2 | 1 | dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52]) gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) - partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + partition_ids = array_ops.constant([1, 1, 1, 2], dtype=dtypes.int32) class_id = -1 gradient_shape = tensor_shape.scalar() @@ -254,7 +254,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertFalse(are_splits_ready) self.assertTrue(are_splits_ready2) - self.assertAllEqual([0, 1], partitions) + self.assertAllEqual([1, 2], partitions) oblivious_split_info = split_info_pb2.ObliviousSplitInfo() oblivious_split_info.ParseFromString(splits[0]) @@ -263,52 +263,57 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.3, split_node.threshold, 0.00001) self.assertEqual(0, split_node.feature_column) - # Check the split on partition 0. + # Check the split on partition 1. # -(1.2 - 0.1) / (0.2 + 1) - expected_left_weight_0 = -0.9166666666666666 + expected_left_weight_1 = -0.9166666666666666 - # expected_left_weight_0 * -(1.2 - 0.1) - expected_left_gain_0 = 1.008333333333333 + # expected_left_weight_1 * -(1.2 - 0.1) + expected_left_gain_1 = 1.008333333333333 # (-0.5 + 0.2 + 0.1) / (0.19 + 1) - expected_right_weight_0 = 0.1680672 + expected_right_weight_1 = 0.1680672 - # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1)) - expected_right_gain_0 = 0.033613445378151252 + # expected_right_weight_1 * -(-0.5 + 0.2 + 0.1)) + expected_right_gain_1 = 0.033613445378151252 # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) - expected_bias_gain_0 = 0.46043165467625896 + expected_bias_gain_1 = 0.46043165467625896 left_child = oblivious_split_info.children[0].vector right_child = oblivious_split_info.children[1].vector - self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001) - - self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001) - - # Check the split on partition 1. - expected_left_weight_1 = 0 - expected_left_gain_1 = 0 - # -(4 - 0.1) / (0.13 + 1) - expected_right_weight_1 = -3.4513274336283186 - # expected_right_weight_1 * -(4 - 0.1) - expected_right_gain_1 = 13.460176991150442 - # (-4 + 0.1) ** 2 / (0.13 + 1) - expected_bias_gain_1 = 13.460176991150442 - - left_child = oblivious_split_info.children[2].vector - right_child = oblivious_split_info.children[3].vector - self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001) self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001) + # Check the split on partition 2. + expected_left_weight_2 = 0 + expected_left_gain_2 = 0 + # -(4 - 0.1) / (0.13 + 1) + expected_right_weight_2 = -3.4513274336283186 + # expected_right_weight_2 * -(4 - 0.1) + expected_right_gain_2 = 13.460176991150442 + # (-4 + 0.1) ** 2 / (0.13 + 1) + expected_bias_gain_2 = 13.460176991150442 + + left_child = oblivious_split_info.children[2].vector + right_child = oblivious_split_info.children[3].vector + + self.assertAllClose([expected_left_weight_2], left_child.value, 0.00001) + + self.assertAllClose([expected_right_weight_2], right_child.value, 0.00001) + # The layer gain is the sum of the gains of each partition layer_gain = ( - expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + ( - expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + ( + expected_left_gain_2 + expected_right_gain_2 - expected_bias_gain_2) self.assertAllClose(layer_gain, gains[0], 0.00001) + # We have examples in both partitions, then we get both ids. + self.assertEqual(2, len(oblivious_split_info.children_parent_id)) + self.assertEqual(1, oblivious_split_info.children_parent_id[0]) + self.assertEqual(2, oblivious_split_info.children_parent_id[1]) + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): with self.test_session() as sess: # The data looks like the following: diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto index 65448996bff..784977af395 100644 --- a/tensorflow/contrib/boosted_trees/proto/split_info.proto +++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto @@ -21,4 +21,8 @@ message SplitInfo { message ObliviousSplitInfo { tensorflow.boosted_trees.trees.TreeNode split_node = 1; repeated tensorflow.boosted_trees.trees.Leaf children = 2; + // For each child, children_parent_id stores the node_id of its parent when it + // was a leaf. For the idx-th child it corresponds the idx/2-th + // children_parent_id. + repeated int32 children_parent_id = 3; } diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index 278dc1f7560..b3e4c2e5f7a 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -91,7 +91,8 @@ def _gen_dense_split_info(fc, threshold, left_weight, right_weight): return split.SerializeToString() -def _gen_dense_oblivious_split_info(fc, threshold, leave_weights): +def _gen_dense_oblivious_split_info(fc, threshold, leave_weights, + children_parent_id): split_str = """ split_node { oblivious_dense_float_binary_split { @@ -107,6 +108,9 @@ def _gen_dense_oblivious_split_info(fc, threshold, leave_weights): } }""" % ( weight) + for x in children_parent_id: + split_str += """ + children_parent_id: %d""" % (x) split = split_info_pb2.ObliviousSplitInfo() text_format.Merge(split_str, split) return split.SerializeToString() @@ -432,14 +436,18 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): handler1_partitions = np.array([0], dtype=np.int32) handler1_gains = np.array([7.62], dtype=np.float32) handler1_split = [ - _gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143]) + _gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143], [0]) ] handler2_partitions = np.array([0], dtype=np.int32) handler2_gains = np.array([0.63], dtype=np.float32) - handler2_split = [_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24])] + handler2_split = [ + _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24], [0]) + ] handler3_partitions = np.array([0], dtype=np.int32) handler3_gains = np.array([7.62], dtype=np.float32) - handler3_split = [_gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143])] + handler3_split = [ + _gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143], [0]) + ] # Grow tree ensemble. grow_op = training_ops.grow_tree_ensemble( @@ -1675,17 +1683,20 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): handler1_partitions = np.array([0], dtype=np.int32) handler1_gains = np.array([1.4], dtype=np.float32) handler1_split = [ - _gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5]) + _gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5], + [1, 2]) ] handler2_partitions = np.array([0], dtype=np.int32) handler2_gains = np.array([2.7], dtype=np.float32) handler2_split = [ - _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4]), + _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4], + [1, 2]) ] handler3_partitions = np.array([0], dtype=np.int32) handler3_gains = np.array([1.7], dtype=np.float32) handler3_split = [ - _gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1]) + _gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1], + [1, 2]) ] # Grow tree ensemble layer by layer. @@ -1797,6 +1808,528 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual(stats.attempted_layers, 2) self.assertProtoEquals(expected_result, tree_ensemble_config) + def testGrowEnsembleWithEmptyNodesMiddleCase(self): + """Test case: The middle existing leaves don't have examples.""" + with self.test_session() as session: + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 4 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 1 + threshold: 0.23 + } + node_metadata { + gain: 2.7 + original_oblivious_leaves { + vector { + value: 7.143 + } + } + original_oblivious_leaves { + vector { + value: -4.375 + } + } + } + } + nodes { + leaf { + vector { + value: 6.543 + } + } + } + nodes { + leaf { + vector { + value: 7.5 + } + } + } + nodes { + leaf { + vector { + value: -4.075 + } + } + } + nodes { + leaf { + vector { + value: -3.975 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 2 + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + } + """, tree_ensemble_config) + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=6, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER) + + # Prepare handler inputs. + handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([1.8], dtype=np.float32) + handler1_split = [ + _gen_dense_oblivious_split_info(0, 0.9, [1.0, 2.0, 3.0, 4.0], [2, 5]) + ] + # The tree currently has depth 2, so the ids for the four leaves are in + # the range [2, 6). In this test case we are assuming that our examples + # only fall in leaves 2 and 5. + + # Grow tree ensemble layer by layer. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=0.1, + partition_ids=[handler1_partitions], + gains=[handler1_gains], + splits=[handler1_split], + learner_config=learner_config.SerializeToString(), + dropout_seed=123, + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + session.run(grow_op) + + new_stamp, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + stats = session.run( + training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1)) + tree_ensemble_config.ParseFromString(serialized) + expected_result = """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 4 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 1 + threshold: 0.23 + } + node_metadata { + gain: 2.7 + original_oblivious_leaves { + vector { + value: 7.143 + } + } + original_oblivious_leaves { + vector { + value: -4.375 + } + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 0 + threshold: 0.9 + } + node_metadata { + gain: 1.8 + original_oblivious_leaves { + vector { + value: 6.543 + } + } + original_oblivious_leaves { + vector { + value: 7.5 + } + } + original_oblivious_leaves { + vector { + value: -4.075 + } + } + original_oblivious_leaves { + vector { + value: -3.975 + } + } + } + } + nodes { + leaf { + vector { + value: 7.543 + } + } + } + nodes { + leaf { + vector { + value: 8.543 + } + } + } + nodes { + leaf { + vector { + value: 7.5 + } + } + } + nodes { + leaf { + vector { + value: 7.5 + } + } + } + nodes { + leaf { + vector { + value: -4.075 + } + } + } + nodes { + leaf { + vector { + value: -4.075 + } + } + } + nodes { + leaf { + vector { + value: -0.975 + } + } + } + nodes { + leaf { + vector { + value: 0.025 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 3 + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 3 + } + """ + self.assertEqual(new_stamp, 1) + self.assertEqual(stats.num_trees, 0) + self.assertEqual(stats.num_layers, 3) + self.assertEqual(stats.active_tree, 1) + self.assertEqual(stats.active_layer, 3) + self.assertEqual(stats.attempted_trees, 1) + self.assertEqual(stats.attempted_layers, 3) + self.assertProtoEquals(expected_result, tree_ensemble_config) + + def testGrowEnsembleWithEmptyNodesBorderCase(self): + """Test case: The first and last existing leaves don't have examples.""" + with self.test_session() as session: + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 4 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 1 + threshold: 0.23 + } + node_metadata { + gain: 2.7 + original_oblivious_leaves { + vector { + value: 7.143 + } + } + original_oblivious_leaves { + vector { + value: -4.375 + } + } + } + } + nodes { + leaf { + vector { + value: 6.543 + } + } + } + nodes { + leaf { + vector { + value: 7.5 + } + } + } + nodes { + leaf { + vector { + value: -4.075 + } + } + } + nodes { + leaf { + vector { + value: -3.975 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 2 + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + } + """, tree_ensemble_config) + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=6, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER) + + # Prepare handler inputs. + handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([1.8], dtype=np.float32) + handler1_split = [ + _gen_dense_oblivious_split_info(0, 0.9, [1.0, 2.0, 3.0, 4.0], [3, 4]) + ] + # The tree currently has depth 2, so the ids for the four leaves are in + # the range [2, 6). In this test case we are assuming that our examples + # only fall in leaves 3 and 4. + + # Grow tree ensemble layer by layer. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=0.1, + partition_ids=[handler1_partitions], + gains=[handler1_gains], + splits=[handler1_split], + learner_config=learner_config.SerializeToString(), + dropout_seed=123, + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + session.run(grow_op) + + new_stamp, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + stats = session.run( + training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1)) + tree_ensemble_config.ParseFromString(serialized) + expected_result = """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 4 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 1 + threshold: 0.23 + } + node_metadata { + gain: 2.7 + original_oblivious_leaves { + vector { + value: 7.143 + } + } + original_oblivious_leaves { + vector { + value: -4.375 + } + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 0 + threshold: 0.9 + } + node_metadata { + gain: 1.8 + original_oblivious_leaves { + vector { + value: 6.543 + } + } + original_oblivious_leaves { + vector { + value: 7.5 + } + } + original_oblivious_leaves { + vector { + value: -4.075 + } + } + original_oblivious_leaves { + vector { + value: -3.975 + } + } + } + } + nodes { + leaf { + vector { + value: 6.543 + } + } + } + nodes { + leaf { + vector { + value: 6.543 + } + } + } + nodes { + leaf { + vector { + value: 8.5 + } + } + } + nodes { + leaf { + vector { + value: 9.5 + } + } + } + nodes { + leaf { + vector { + value: -1.075 + } + } + } + nodes { + leaf { + vector { + value: -0.075 + } + } + } + nodes { + leaf { + vector { + value: -3.975 + } + } + } + nodes { + leaf { + vector { + value: -3.975 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 3 + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 3 + } + """ + self.assertEqual(new_stamp, 1) + self.assertEqual(stats.num_trees, 0) + self.assertEqual(stats.num_layers, 3) + self.assertEqual(stats.active_tree, 1) + self.assertEqual(stats.active_layer, 3) + self.assertEqual(stats.attempted_trees, 1) + self.assertEqual(stats.attempted_layers, 3) + self.assertProtoEquals(expected_result, tree_ensemble_config) + def testGrowExistingEnsembleTreeFinalizedWithDropout(self): """Test growing an existing ensemble with the last tree finalized.""" with self.cached_session() as session: