Fix the interaction between the split handler ops and the grow tree ensemble ops when we have nodes with no examples.
PiperOrigin-RevId: 209830550
parent 091c9809b8
commit ce127f779d
@@ -400,6 +400,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
const int start_index = partition_boundaries[root_idx];
(*output_partition_ids)(root_idx) = partition_ids(start_index);
oblivious_split_info.add_children_parent_id(partition_ids(start_index));
}
oblivious_split_info.SerializeToString(&(*output_splits)(0));
}
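For orientation only (this sketch is not part of the commit): the handler-side change above records one children_parent_id per partition that actually produced a split, so the grow op can later tell which existing leaf each pair of new children belongs to. A minimal Python sketch, assuming plain lists in place of the op's tensors:

    # Illustrative sketch; partition_boundaries[i] is assumed to be the first
    # example index of the i-th non-empty partition in the batch.
    def collect_children_parent_ids(partition_ids, partition_boundaries):
      parent_ids = []
      for start_index in partition_boundaries:
        # Mirror of add_children_parent_id(partition_ids(start_index)) above.
        parent_ids.append(partition_ids[start_index])
      return parent_ids

    # With partition_ids = [1, 1, 1, 2] and boundaries [0, 3] this yields
    # [1, 2], matching the updated DenseSplitHandlerTest expectations below.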
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
@@ -772,20 +774,32 @@ class GrowTreeEnsembleOp : public OpKernel {
// The number of new children.
int num_children = 1 << (depth + 1);
auto split_info = split->oblivious_split_info;
CHECK(num_children == split_info.children_size())
<< "Wrong number of new children: " << num_children
<< " != " << split_info.children_size();
for (int idx = 0; idx < num_children; idx += 2) {
// Old leaf is at position depth + idx / 2.
CHECK(num_children >= split_info.children_size())
<< "Too many new children, expected <= " << num_children << " and got "
<< split_info.children_size();
std::vector<trees::Leaf> new_leaves;
new_leaves.reserve(num_children);
int next_id = 0;
for (int idx = 0; idx < num_children / 2; idx++) {
trees::Leaf old_leaf =
*tree_config->mutable_nodes(depth + idx / 2)->mutable_leaf();
// Update left leaf.
*split_info.mutable_children(idx) =
*MergeLeafWeights(old_leaf, split_info.mutable_children(idx));
// Update right leaf.
*split_info.mutable_children(idx + 1) =
*MergeLeafWeights(old_leaf, split_info.mutable_children(idx + 1));
*tree_config->mutable_nodes(depth + idx)->mutable_leaf();
// Check if a split was made for this leaf.
if (next_id < split_info.children_parent_id_size() &&
depth + idx == split_info.children_parent_id(next_id)) {
// Add left leaf.
new_leaves.push_back(*MergeLeafWeights(
old_leaf, split_info.mutable_children(2 * next_id)));
// Add right leaf.
new_leaves.push_back(*MergeLeafWeights(
old_leaf, split_info.mutable_children(2 * next_id + 1)));
next_id++;
} else {
// If there is no split for this leaf, just duplicate it.
new_leaves.push_back(old_leaf);
new_leaves.push_back(old_leaf);
}
}
CHECK(next_id == split_info.children_parent_id_size());
TreeNodeMetadata* split_metadata =
split_info.mutable_split_node()->mutable_node_metadata();
split_metadata->set_gain(split->gain);
@@ -804,11 +818,10 @@ class GrowTreeEnsembleOp : public OpKernel {
if (idx + depth + 1 < nodes_size) {
// Update leaves that were already there.
*tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() =
*split_info.mutable_children(idx);
new_leaves[idx];
} else {
// Add new leaves.
*tree_config->add_nodes()->mutable_leaf() =
*split_info.mutable_children(idx);
*tree_config->add_nodes()->mutable_leaf() = new_leaves[idx];
}
}
}
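To make the new grow logic above easier to follow, here is a minimal Python sketch of it (illustrative only, not part of the diff). The helper names are hypothetical; leaf merging is modeled as elementwise addition of the weight vectors, which is consistent with the expected leaf values in the tests below.

    def merge_leaf_weights(old_leaf, child):
      # Leaves are modeled as plain lists of weights of equal length.
      return [o + c for o, c in zip(old_leaf, child)]

    def build_new_leaves(old_leaves, split_children, children_parent_id, depth):
      """Either attach merged split children to a leaf (if the handler saw
      examples for it) or duplicate the old leaf unchanged."""
      new_leaves = []
      next_id = 0
      for idx, old_leaf in enumerate(old_leaves):
        if (next_id < len(children_parent_id) and
            depth + idx == children_parent_id[next_id]):
          # A split was produced for this leaf: merge the old weights into
          # both of its children, like MergeLeafWeights in the op.
          new_leaves.append(
              merge_leaf_weights(old_leaf, split_children[2 * next_id]))
          new_leaves.append(
              merge_leaf_weights(old_leaf, split_children[2 * next_id + 1]))
          next_id += 1
        else:
          # No examples reached this leaf, so keep its value for both children.
          new_leaves.append(old_leaf)
          new_leaves.append(old_leaf)
      assert next_id == len(children_parent_id)
      return new_leaves

    # Example mirroring testGrowEnsembleWithEmptyNodesMiddleCase below: only
    # the leaves with node ids 2 and 5 received examples.
    old = [[6.543], [7.5], [-4.075], [-3.975]]
    children = [[1.0], [2.0], [3.0], [4.0]]
    build_new_leaves(old, children, children_parent_id=[2, 5], depth=2)
    # -> leaves close to 7.543, 8.543, 7.5, 7.5, -4.075, -4.075, -0.975, 0.025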
@@ -186,14 +186,14 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
with self.test_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 2 |
# i1 | (-0.5, 0.07) | 0 | 2 |
# i2 | (1.2, 0.2) | 0 | 0 |
# i3 | (4.0, 0.13) | 1 | 1 |
# i0 | (0.2, 0.12) | 1 | 2 |
# i1 | (-0.5, 0.07) | 1 | 2 |
# i2 | (1.2, 0.2) | 1 | 0 |
# i3 | (4.0, 0.13) | 2 | 1 |
dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52])
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
partition_ids = array_ops.constant([1, 1, 1, 2], dtype=dtypes.int32)
class_id = -1

gradient_shape = tensor_shape.scalar()
@@ -254,7 +254,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertFalse(are_splits_ready)
self.assertTrue(are_splits_ready2)

self.assertAllEqual([0, 1], partitions)
self.assertAllEqual([1, 2], partitions)

oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
oblivious_split_info.ParseFromString(splits[0])
@@ -263,52 +263,57 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.3, split_node.threshold, 0.00001)
self.assertEqual(0, split_node.feature_column)

# Check the split on partition 0.
# Check the split on partition 1.
# -(1.2 - 0.1) / (0.2 + 1)
expected_left_weight_0 = -0.9166666666666666
expected_left_weight_1 = -0.9166666666666666

# expected_left_weight_0 * -(1.2 - 0.1)
expected_left_gain_0 = 1.008333333333333
# expected_left_weight_1 * -(1.2 - 0.1)
expected_left_gain_1 = 1.008333333333333

# (-0.5 + 0.2 + 0.1) / (0.19 + 1)
expected_right_weight_0 = 0.1680672
expected_right_weight_1 = 0.1680672

# expected_right_weight_0 * -(-0.5 + 0.2 + 0.1))
expected_right_gain_0 = 0.033613445378151252
# expected_right_weight_1 * -(-0.5 + 0.2 + 0.1))
expected_right_gain_1 = 0.033613445378151252

# (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
expected_bias_gain_0 = 0.46043165467625896
expected_bias_gain_1 = 0.46043165467625896

left_child = oblivious_split_info.children[0].vector
right_child = oblivious_split_info.children[1].vector

self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)

self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001)

# Check the split on partition 1.
expected_left_weight_1 = 0
expected_left_gain_1 = 0
# -(4 - 0.1) / (0.13 + 1)
expected_right_weight_1 = -3.4513274336283186
# expected_right_weight_1 * -(4 - 0.1)
expected_right_gain_1 = 13.460176991150442
# (-4 + 0.1) ** 2 / (0.13 + 1)
expected_bias_gain_1 = 13.460176991150442

left_child = oblivious_split_info.children[2].vector
right_child = oblivious_split_info.children[3].vector

self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)

self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)

# Check the split on partition 2.
expected_left_weight_2 = 0
expected_left_gain_2 = 0
# -(4 - 0.1) / (0.13 + 1)
expected_right_weight_2 = -3.4513274336283186
# expected_right_weight_2 * -(4 - 0.1)
expected_right_gain_2 = 13.460176991150442
# (-4 + 0.1) ** 2 / (0.13 + 1)
expected_bias_gain_2 = 13.460176991150442

left_child = oblivious_split_info.children[2].vector
right_child = oblivious_split_info.children[3].vector

self.assertAllClose([expected_left_weight_2], left_child.value, 0.00001)

self.assertAllClose([expected_right_weight_2], right_child.value, 0.00001)

# The layer gain is the sum of the gains of each partition
layer_gain = (
expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + (
expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1)
expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + (
expected_left_gain_2 + expected_right_gain_2 - expected_bias_gain_2)
self.assertAllClose(layer_gain, gains[0], 0.00001)

# We have examples in both partitions, then we get both ids.
self.assertEqual(2, len(oblivious_split_info.children_parent_id))
self.assertEqual(1, oblivious_split_info.children_parent_id[0])
self.assertEqual(2, oblivious_split_info.children_parent_id[1])

def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
with self.test_session() as sess:
# The data looks like the following:
@@ -21,4 +21,8 @@ message SplitInfo {
message ObliviousSplitInfo {
tensorflow.boosted_trees.trees.TreeNode split_node = 1;
repeated tensorflow.boosted_trees.trees.Leaf children = 2;
// For each child, children_parent_id stores the node_id of its parent when it
// was a leaf. The idx-th child corresponds to the (idx/2)-th
// children_parent_id.
repeated int32 children_parent_id = 3;
}
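As a small illustration of the comment above (not part of the diff): children come in (left, right) pairs, and children_parent_id[k] holds the node id of the leaf that children[2*k] and children[2*k + 1] replace.

    # Illustrative pairing only; the string labels are placeholders.
    children = ["left_of_2", "right_of_2", "left_of_5", "right_of_5"]
    children_parent_id = [2, 5]
    for k, parent in enumerate(children_parent_id):
      left, right = children[2 * k], children[2 * k + 1]
      print("leaf %d -> (%s, %s)" % (parent, left, right))
    # leaf 2 -> (left_of_2, right_of_2)
    # leaf 5 -> (left_of_5, right_of_5)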
@@ -91,7 +91,8 @@ def _gen_dense_split_info(fc, threshold, left_weight, right_weight):
return split.SerializeToString()


def _gen_dense_oblivious_split_info(fc, threshold, leave_weights):
def _gen_dense_oblivious_split_info(fc, threshold, leave_weights,
children_parent_id):
split_str = """
split_node {
oblivious_dense_float_binary_split {
@@ -107,6 +108,9 @@ def _gen_dense_oblivious_split_info(fc, threshold, leave_weights):
}
}""" % (
weight)
for x in children_parent_id:
split_str += """
children_parent_id: %d""" % (x)
split = split_info_pb2.ObliviousSplitInfo()
text_format.Merge(split_str, split)
return split.SerializeToString()
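For reference, a sketch of what the new loop above appends to the text proto (illustrative only): each entry of children_parent_id becomes one repeated-field line that text_format.Merge folds into the ObliviousSplitInfo.

    extra = ""
    for x in [1, 2]:
      extra += """
        children_parent_id: %d""" % (x)
    # extra now contains the two lines "children_parent_id: 1" and
    # "children_parent_id: 2", matching the updated test calls below.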
@@ -432,14 +436,18 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
handler1_partitions = np.array([0], dtype=np.int32)
handler1_gains = np.array([7.62], dtype=np.float32)
handler1_split = [
_gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143])
_gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143], [0])
]
handler2_partitions = np.array([0], dtype=np.int32)
handler2_gains = np.array([0.63], dtype=np.float32)
handler2_split = [_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24])]
handler2_split = [
_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24], [0])
]
handler3_partitions = np.array([0], dtype=np.int32)
handler3_gains = np.array([7.62], dtype=np.float32)
handler3_split = [_gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143])]
handler3_split = [
_gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143], [0])
]

# Grow tree ensemble.
grow_op = training_ops.grow_tree_ensemble(
@@ -1675,17 +1683,20 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
handler1_partitions = np.array([0], dtype=np.int32)
handler1_gains = np.array([1.4], dtype=np.float32)
handler1_split = [
_gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5])
_gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5],
[1, 2])
]
handler2_partitions = np.array([0], dtype=np.int32)
handler2_gains = np.array([2.7], dtype=np.float32)
handler2_split = [
_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4]),
_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4],
[1, 2])
]
handler3_partitions = np.array([0], dtype=np.int32)
handler3_gains = np.array([1.7], dtype=np.float32)
handler3_split = [
_gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1])
_gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1],
[1, 2])
]

# Grow tree ensemble layer by layer.
@@ -1797,6 +1808,528 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(stats.attempted_layers, 2)
self.assertProtoEquals(expected_result, tree_ensemble_config)

def testGrowEnsembleWithEmptyNodesMiddleCase(self):
"""Test case: The middle existing leaves don't have examples."""
with self.test_session() as session:
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
"""
trees {
nodes {
oblivious_dense_float_binary_split {
feature_column: 4
threshold: 7
}
node_metadata {
gain: 7.62
original_oblivious_leaves {
}
}
}
nodes {
oblivious_dense_float_binary_split {
feature_column: 1
threshold: 0.23
}
node_metadata {
gain: 2.7
original_oblivious_leaves {
vector {
value: 7.143
}
}
original_oblivious_leaves {
vector {
value: -4.375
}
}
}
}
nodes {
leaf {
vector {
value: 6.543
}
}
}
nodes {
leaf {
vector {
value: 7.5
}
}
}
nodes {
leaf {
vector {
value: -4.075
}
}
}
nodes {
leaf {
vector {
value: -3.975
}
}
}
}
tree_weights: 0.1
tree_metadata {
num_tree_weight_updates: 1
num_layers_grown: 2
}
growing_metadata {
num_trees_attempted: 1
num_layers_attempted: 2
}
""", tree_ensemble_config)
tree_ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0,
tree_ensemble_config=tree_ensemble_config.SerializeToString(),
name="tree_ensemble")
resources.initialize_resources(resources.shared_resources()).run()

# Prepare learner config.
learner_config = _gen_learner_config(
num_classes=2,
l1_reg=0,
l2_reg=0,
tree_complexity=0,
max_depth=6,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER)

# Prepare handler inputs.
handler1_partitions = np.array([0], dtype=np.int32)
handler1_gains = np.array([1.8], dtype=np.float32)
handler1_split = [
_gen_dense_oblivious_split_info(0, 0.9, [1.0, 2.0, 3.0, 4.0], [2, 5])
]
# The tree currently has depth 2, so the ids for the four leaves are in
# the range [2, 6). In this test case we are assuming that our examples
# only fall in leaves 2 and 5.

# Grow tree ensemble layer by layer.
grow_op = training_ops.grow_tree_ensemble(
tree_ensemble_handle,
stamp_token=0,
next_stamp_token=1,
learning_rate=0.1,
partition_ids=[handler1_partitions],
gains=[handler1_gains],
splits=[handler1_split],
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
session.run(grow_op)

new_stamp, serialized = session.run(
model_ops.tree_ensemble_serialize(tree_ensemble_handle))
stats = session.run(
training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
tree_ensemble_config.ParseFromString(serialized)
expected_result = """
trees {
nodes {
oblivious_dense_float_binary_split {
feature_column: 4
threshold: 7
}
node_metadata {
gain: 7.62
original_oblivious_leaves {
}
}
}
nodes {
oblivious_dense_float_binary_split {
feature_column: 1
threshold: 0.23
}
node_metadata {
gain: 2.7
original_oblivious_leaves {
vector {
value: 7.143
}
}
original_oblivious_leaves {
vector {
value: -4.375
}
}
}
}
nodes {
oblivious_dense_float_binary_split {
feature_column: 0
threshold: 0.9
}
node_metadata {
gain: 1.8
original_oblivious_leaves {
vector {
value: 6.543
}
}
original_oblivious_leaves {
vector {
value: 7.5
}
}
original_oblivious_leaves {
vector {
value: -4.075
}
}
original_oblivious_leaves {
vector {
value: -3.975
}
}
}
}
nodes {
leaf {
vector {
value: 7.543
}
}
}
nodes {
leaf {
vector {
value: 8.543
}
}
}
nodes {
leaf {
vector {
value: 7.5
}
}
}
nodes {
leaf {
vector {
value: 7.5
}
}
}
nodes {
leaf {
vector {
value: -4.075
}
}
}
nodes {
leaf {
vector {
value: -4.075
}
}
}
nodes {
leaf {
vector {
value: -0.975
}
}
}
nodes {
leaf {
vector {
value: 0.025
}
}
}
}
tree_weights: 0.1
tree_metadata {
num_tree_weight_updates: 1
num_layers_grown: 3
}
growing_metadata {
num_trees_attempted: 1
num_layers_attempted: 3
}
"""
self.assertEqual(new_stamp, 1)
self.assertEqual(stats.num_trees, 0)
self.assertEqual(stats.num_layers, 3)
self.assertEqual(stats.active_tree, 1)
self.assertEqual(stats.active_layer, 3)
self.assertEqual(stats.attempted_trees, 1)
self.assertEqual(stats.attempted_layers, 3)
self.assertProtoEquals(expected_result, tree_ensemble_config)

def testGrowEnsembleWithEmptyNodesBorderCase(self):
"""Test case: The first and last existing leaves don't have examples."""
with self.test_session() as session:
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
"""
trees {
nodes {
oblivious_dense_float_binary_split {
feature_column: 4
threshold: 7
}
node_metadata {
gain: 7.62
original_oblivious_leaves {
}
}
}
nodes {
oblivious_dense_float_binary_split {
feature_column: 1
threshold: 0.23
}
node_metadata {
gain: 2.7
original_oblivious_leaves {
vector {
value: 7.143
}
}
original_oblivious_leaves {
vector {
value: -4.375
}
}
}
}
nodes {
leaf {
vector {
value: 6.543
}
}
}
nodes {
leaf {
vector {
value: 7.5
}
}
}
nodes {
leaf {
vector {
value: -4.075
}
}
}
nodes {
leaf {
vector {
value: -3.975
}
}
}
}
tree_weights: 0.1
tree_metadata {
num_tree_weight_updates: 1
num_layers_grown: 2
}
growing_metadata {
num_trees_attempted: 1
num_layers_attempted: 2
}
""", tree_ensemble_config)
tree_ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0,
tree_ensemble_config=tree_ensemble_config.SerializeToString(),
name="tree_ensemble")
resources.initialize_resources(resources.shared_resources()).run()

# Prepare learner config.
learner_config = _gen_learner_config(
num_classes=2,
l1_reg=0,
l2_reg=0,
tree_complexity=0,
max_depth=6,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER)

# Prepare handler inputs.
handler1_partitions = np.array([0], dtype=np.int32)
handler1_gains = np.array([1.8], dtype=np.float32)
handler1_split = [
_gen_dense_oblivious_split_info(0, 0.9, [1.0, 2.0, 3.0, 4.0], [3, 4])
]
# The tree currently has depth 2, so the ids for the four leaves are in
# the range [2, 6). In this test case we are assuming that our examples
# only fall in leaves 3 and 4.

# Grow tree ensemble layer by layer.
grow_op = training_ops.grow_tree_ensemble(
tree_ensemble_handle,
stamp_token=0,
next_stamp_token=1,
learning_rate=0.1,
partition_ids=[handler1_partitions],
gains=[handler1_gains],
splits=[handler1_split],
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
session.run(grow_op)

new_stamp, serialized = session.run(
model_ops.tree_ensemble_serialize(tree_ensemble_handle))
stats = session.run(
training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
tree_ensemble_config.ParseFromString(serialized)
expected_result = """
trees {
nodes {
oblivious_dense_float_binary_split {
feature_column: 4
threshold: 7
}
node_metadata {
gain: 7.62
original_oblivious_leaves {
}
}
}
nodes {
oblivious_dense_float_binary_split {
feature_column: 1
threshold: 0.23
}
node_metadata {
gain: 2.7
original_oblivious_leaves {
vector {
value: 7.143
}
}
original_oblivious_leaves {
vector {
value: -4.375
}
}
}
}
nodes {
oblivious_dense_float_binary_split {
feature_column: 0
threshold: 0.9
}
node_metadata {
gain: 1.8
original_oblivious_leaves {
vector {
value: 6.543
}
}
original_oblivious_leaves {
vector {
value: 7.5
}
}
original_oblivious_leaves {
vector {
value: -4.075
}
}
original_oblivious_leaves {
vector {
value: -3.975
}
}
}
}
nodes {
leaf {
vector {
value: 6.543
}
}
}
nodes {
leaf {
vector {
value: 6.543
}
}
}
nodes {
leaf {
vector {
value: 8.5
}
}
}
nodes {
leaf {
vector {
value: 9.5
}
}
}
nodes {
leaf {
vector {
value: -1.075
}
}
}
nodes {
leaf {
vector {
value: -0.075
}
}
}
nodes {
leaf {
vector {
value: -3.975
}
}
}
nodes {
leaf {
vector {
value: -3.975
}
}
}
}
tree_weights: 0.1
tree_metadata {
num_tree_weight_updates: 1
num_layers_grown: 3
}
growing_metadata {
num_trees_attempted: 1
num_layers_attempted: 3
}
"""
self.assertEqual(new_stamp, 1)
self.assertEqual(stats.num_trees, 0)
self.assertEqual(stats.num_layers, 3)
self.assertEqual(stats.active_tree, 1)
self.assertEqual(stats.active_layer, 3)
self.assertEqual(stats.attempted_trees, 1)
self.assertEqual(stats.attempted_layers, 3)
self.assertProtoEquals(expected_result, tree_ensemble_config)

def testGrowExistingEnsembleTreeFinalizedWithDropout(self):
"""Test growing an existing ensemble with the last tree finalized."""
with self.cached_session() as session: