1) Update the proto files for oblivious trees.

2) Grow a new layer of an oblivious tree.

PiperOrigin-RevId: 209633300
This commit is contained in:
A. Unique TensorFlower 2018-08-21 11:47:22 -07:00 committed by TensorFlower Gardener
parent e787c15ae8
commit e28f9da84b
10 changed files with 583 additions and 49 deletions

View File

@ -445,6 +445,7 @@ tf_kernel_library(
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource", "//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
"//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_headers_lib",

View File

@ -383,19 +383,20 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
best_gain -= num_elements * state->tree_complexity_regularization(); best_gain -= num_elements * state->tree_complexity_regularization();
ObliviousSplitInfo oblivious_split_info; ObliviousSplitInfo oblivious_split_info;
auto* oblivious_dense_split = oblivious_split_info.mutable_split_node() auto* oblivious_dense_split =
->mutable_dense_float_binary_split(); oblivious_split_info.mutable_split_node()
->mutable_oblivious_dense_float_binary_split();
oblivious_dense_split->set_feature_column(state->feature_column_group_id()); oblivious_dense_split->set_feature_column(state->feature_column_group_id());
oblivious_dense_split->set_threshold( oblivious_dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0))); bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
(*gains)(0) = best_gain; (*gains)(0) = best_gain;
for (int root_idx = 0; root_idx < num_elements; root_idx++) { for (int root_idx = 0; root_idx < num_elements; root_idx++) {
auto* left_children = oblivious_split_info.add_children_leaves(); auto* left_child = oblivious_split_info.add_children();
auto* right_children = oblivious_split_info.add_children_leaves(); auto* right_child = oblivious_split_info.add_children();
state->FillLeaf(best_left_node_stats[root_idx], left_children); state->FillLeaf(best_left_node_stats[root_idx], left_child);
state->FillLeaf(best_right_node_stats[root_idx], right_children); state->FillLeaf(best_right_node_stats[root_idx], right_child);
const int start_index = partition_boundaries[root_idx]; const int start_index = partition_boundaries[root_idx];
(*output_partition_ids)(root_idx) = partition_ids(start_index); (*output_partition_ids)(root_idx) = partition_ids(start_index);

View File

@ -15,6 +15,7 @@
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h" #include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" #include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h" #include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h" #include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
@ -26,6 +27,7 @@ namespace boosted_trees {
namespace { namespace {
using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearningRateConfig; using boosted_trees::learner::LearningRateConfig;
using boosted_trees::trees::Leaf; using boosted_trees::trees::Leaf;
using boosted_trees::trees::TreeNode; using boosted_trees::trees::TreeNode;
@ -42,6 +44,9 @@ struct SplitCandidate {
// Split info. // Split info.
learner::SplitInfo split_info; learner::SplitInfo split_info;
// Oblivious split info.
learner::ObliviousSplitInfo oblivious_split_info;
}; };
// Checks that the leaf is not empty. // Checks that the leaf is not empty.
@ -343,7 +348,12 @@ class GrowTreeEnsembleOp : public OpKernel {
OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t)); OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
float learning_rate = learning_rate_t->scalar<float>()(); float learning_rate = learning_rate_t->scalar<float>()();
// Read seed that was used for dropout. // Read the weak learner type to use.
const Tensor* weak_learner_type_t;
OP_REQUIRES_OK(context,
context->input("weak_learner_type", &weak_learner_type_t));
const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
const Tensor* seed_t; const Tensor* seed_t;
OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t)); OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t));
// Cast seed to uint64. // Cast seed to uint64.
@ -363,9 +373,18 @@ class GrowTreeEnsembleOp : public OpKernel {
// Find best splits for each active partition. // Find best splits for each active partition.
std::map<int32, SplitCandidate> best_splits; std::map<int32, SplitCandidate> best_splits;
FindBestSplitsPerPartition(context, partition_ids_list, gains_list, switch (weak_learner_type) {
splits_list, &best_splits); case LearnerConfig::NORMAL_DECISION_TREE: {
FindBestSplitsPerPartitionNormal(context, partition_ids_list,
gains_list, splits_list, &best_splits);
break;
}
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
FindBestSplitsPerPartitionOblivious(context, gains_list, splits_list,
&best_splits);
break;
}
}
// No-op if no new splits can be considered. // No-op if no new splits can be considered.
if (best_splits.empty()) { if (best_splits.empty()) {
LOG(WARNING) << "Not growing tree ensemble as no good splits were found."; LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
@ -377,25 +396,34 @@ class GrowTreeEnsembleOp : public OpKernel {
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->input("max_tree_depth", &max_tree_depth_t)); context->input("max_tree_depth", &max_tree_depth_t));
const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()(); const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()();
// Update and retrieve the growable tree. // Update and retrieve the growable tree.
// If the tree is fully built and dropout was applied, it also adjusts the // If the tree is fully built and dropout was applied, it also adjusts the
// weights of dropped and the last tree. // weights of dropped and the last tree.
boosted_trees::trees::DecisionTreeConfig* const tree_config = boosted_trees::trees::DecisionTreeConfig* const tree_config =
UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate, UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
dropout_seed, max_tree_depth); dropout_seed, max_tree_depth,
weak_learner_type);
// Split tree nodes. // Split tree nodes.
switch (weak_learner_type) {
case LearnerConfig::NORMAL_DECISION_TREE: {
for (auto& split_entry : best_splits) { for (auto& split_entry : best_splits) {
SplitTreeNode(split_entry.first, &split_entry.second, tree_config, SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
ensemble_resource); ensemble_resource);
} }
break;
}
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
SplitTreeLayer(&best_splits[0], tree_config, ensemble_resource);
}
}
// Post-prune finalized tree if needed. // Post-prune finalized tree if needed.
if (learner_config_.pruning_mode() == if (learner_config_.pruning_mode() ==
boosted_trees::learner::LearnerConfig::POST_PRUNE && boosted_trees::learner::LearnerConfig::POST_PRUNE &&
ensemble_resource->LastTreeMetadata()->is_finalized()) { ensemble_resource->LastTreeMetadata()->is_finalized()) {
VLOG(2) << "Post-pruning finalized tree."; VLOG(2) << "Post-pruning finalized tree.";
if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE) {
LOG(FATAL) << "Post-prunning is not implemented for Oblivious trees.";
}
PruneTree(tree_config); PruneTree(tree_config);
// If after post-pruning the whole tree has no gain, remove the tree // If after post-pruning the whole tree has no gain, remove the tree
@ -409,10 +437,9 @@ class GrowTreeEnsembleOp : public OpKernel {
private: private:
// Helper method which effectively does a reduce over all split candidates // Helper method which effectively does a reduce over all split candidates
// and finds the best split for each partition. // and finds the best split for each partition.
void FindBestSplitsPerPartition( void FindBestSplitsPerPartitionNormal(
OpKernelContext* const context, OpKernelContext* const context, const OpInputList& partition_ids_list,
const OpInputList& partition_ids_list, const OpInputList& gains_list, const OpInputList& gains_list, const OpInputList& splits_list,
const OpInputList& splits_list,
std::map<int32, SplitCandidate>* best_splits) { std::map<int32, SplitCandidate>* best_splits) {
// Find best split per partition going through every feature candidate. // Find best split per partition going through every feature candidate.
// TODO(salehay): Is this worth parallelizing? // TODO(salehay): Is this worth parallelizing?
@ -446,6 +473,90 @@ class GrowTreeEnsembleOp : public OpKernel {
} }
} }
void FindBestSplitsPerPartitionOblivious(
OpKernelContext* const context, const OpInputList& gains_list,
const OpInputList& splits_list,
std::map<int32, SplitCandidate>* best_splits) {
// Find best split per partition going through every feature candidate.
for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
const auto& gains = gains_list[handler_id].vec<float>();
const auto& splits = splits_list[handler_id].vec<string>();
OP_REQUIRES(context, gains.size() == 1,
errors::InvalidArgument(
"Gains size must be one for oblivious weak learner: ",
gains.size(), " != ", 1));
OP_REQUIRES(context, splits.size() == 1,
errors::InvalidArgument(
"Splits size must be one for oblivious weak learner: ",
splits.size(), " != ", 1));
// Get current split candidate.
const auto& gain = gains(0);
const auto& serialized_split = splits(0);
SplitCandidate split;
split.handler_id = handler_id;
split.gain = gain;
OP_REQUIRES(
context, split.oblivious_split_info.ParseFromString(serialized_split),
errors::InvalidArgument("Unable to parse oblivious split info."));
auto split_info = split.oblivious_split_info;
CHECK(split_info.children_size() % 2 == 0)
<< "The oblivious split should generate an even number of children: "
<< split_info.children_size();
// If every node is pure, then we shouldn't split.
bool only_pure_nodes = true;
for (int idx = 0; idx < split_info.children_size(); idx += 2) {
if (IsLeafWellFormed(*split_info.mutable_children(idx)) &&
IsLeafWellFormed(*split_info.mutable_children(idx + 1))) {
only_pure_nodes = false;
break;
}
}
if (only_pure_nodes) {
VLOG(1) << "The oblivious split does not actually split anything.";
continue;
}
// Don't consider negative splits if we're pre-pruning the tree.
if (learner_config_.pruning_mode() == learner::LearnerConfig::PRE_PRUNE &&
gain < 0) {
continue;
}
// Take the split if we don't have a candidate yet.
auto best_split_it = best_splits->find(0);
if (best_split_it == best_splits->end()) {
best_splits->insert(std::make_pair(0, std::move(split)));
continue;
}
// Determine if we should update best split.
SplitCandidate& best_split = best_split_it->second;
trees::TreeNode current_node = split_info.split_node();
trees::TreeNode best_node = best_split.oblivious_split_info.split_node();
if (TF_PREDICT_FALSE(gain == best_split.gain)) {
// Tie break on node case preferring simpler tree node types.
VLOG(2) << "Attempting to tie break with smaller node case. "
<< "(current split: " << current_node.node_case()
<< ", best split: " << best_node.node_case() << ")";
if (current_node.node_case() < best_node.node_case()) {
best_split = std::move(split);
} else if (current_node.node_case() == best_node.node_case()) {
// Tie break on handler Id.
VLOG(2) << "Tie breaking with higher handler Id. "
<< "(current split: " << handler_id
<< ", best split: " << best_split.handler_id << ")";
if (handler_id > best_split.handler_id) {
best_split = std::move(split);
}
}
} else if (gain > best_split.gain) {
best_split = std::move(split);
}
}
}
void UpdateTreeWeightsIfDropout( void UpdateTreeWeightsIfDropout(
boosted_trees::models::DecisionTreeEnsembleResource* const boosted_trees::models::DecisionTreeEnsembleResource* const
ensemble_resource, ensemble_resource,
@ -501,7 +612,7 @@ class GrowTreeEnsembleOp : public OpKernel {
boosted_trees::models::DecisionTreeEnsembleResource* const boosted_trees::models::DecisionTreeEnsembleResource* const
ensemble_resource, ensemble_resource,
const float learning_rate, const uint64 dropout_seed, const float learning_rate, const uint64 dropout_seed,
const int32 max_tree_depth) { const int32 max_tree_depth, const int32 weak_learner_type) {
const auto num_trees = ensemble_resource->num_trees(); const auto num_trees = ensemble_resource->num_trees();
if (num_trees <= 0 || if (num_trees <= 0 ||
ensemble_resource->LastTreeMetadata()->is_finalized()) { ensemble_resource->LastTreeMetadata()->is_finalized()) {
@ -647,6 +758,60 @@ class GrowTreeEnsembleOp : public OpKernel {
} }
} }
void SplitTreeLayer(
SplitCandidate* split,
boosted_trees::trees::DecisionTreeConfig* tree_config,
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
int depth = 0;
while (depth < tree_config->nodes_size() &&
tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
depth++;
}
CHECK(tree_config->nodes_size() > 0)
<< "A tree must have at least one dummy leaf.";
// The number of new children.
int num_children = 1 << (depth + 1);
auto split_info = split->oblivious_split_info;
CHECK(num_children == split_info.children_size())
<< "Wrong number of new children: " << num_children
<< " != " << split_info.children_size();
for (int idx = 0; idx < num_children; idx += 2) {
// Old leaf is at position depth + idx / 2.
trees::Leaf old_leaf =
*tree_config->mutable_nodes(depth + idx / 2)->mutable_leaf();
// Update left leaf.
*split_info.mutable_children(idx) =
*MergeLeafWeights(old_leaf, split_info.mutable_children(idx));
// Update right leaf.
*split_info.mutable_children(idx + 1) =
*MergeLeafWeights(old_leaf, split_info.mutable_children(idx + 1));
}
TreeNodeMetadata* split_metadata =
split_info.mutable_split_node()->mutable_node_metadata();
split_metadata->set_gain(split->gain);
TreeNode new_split = *split_info.mutable_split_node();
// Move old children to metadata.
for (int idx = depth; idx < tree_config->nodes_size(); idx++) {
*new_split.mutable_node_metadata()->add_original_oblivious_leaves() =
*tree_config->mutable_nodes(idx)->mutable_leaf();
}
// Add the new split to the tree_config in place before the children start.
*tree_config->mutable_nodes(depth) = new_split;
// Add the new children
int nodes_size = tree_config->nodes_size();
for (int idx = 0; idx < num_children; idx++) {
if (idx + depth + 1 < nodes_size) {
// Update leaves that were already there.
*tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() =
*split_info.mutable_children(idx);
} else {
// Add new leaves.
*tree_config->add_nodes()->mutable_leaf() =
*split_info.mutable_children(idx);
}
}
}
void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) { void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
// No-op if tree is empty. // No-op if tree is empty.
if (tree_config->nodes_size() <= 0) { if (tree_config->nodes_size() <= 0) {

View File

@ -258,8 +258,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
oblivious_split_info = split_info_pb2.ObliviousSplitInfo() oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
oblivious_split_info.ParseFromString(splits[0]) oblivious_split_info.ParseFromString(splits[0])
split_node = oblivious_split_info.split_node.dense_float_binary_split split_node = oblivious_split_info.split_node
split_node = split_node.oblivious_dense_float_binary_split
self.assertAllClose(0.3, split_node.threshold, 0.00001) self.assertAllClose(0.3, split_node.threshold, 0.00001)
self.assertEqual(0, split_node.feature_column) self.assertEqual(0, split_node.feature_column)
@ -279,8 +279,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
# (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
expected_bias_gain_0 = 0.46043165467625896 expected_bias_gain_0 = 0.46043165467625896
left_child = oblivious_split_info.children_leaves[0].vector left_child = oblivious_split_info.children[0].vector
right_child = oblivious_split_info.children_leaves[1].vector right_child = oblivious_split_info.children[1].vector
self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001) self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
@ -296,8 +296,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
# (-4 + 0.1) ** 2 / (0.13 + 1) # (-4 + 0.1) ** 2 / (0.13 + 1)
expected_bias_gain_1 = 13.460176991150442 expected_bias_gain_1 = 13.460176991150442
left_child = oblivious_split_info.children_leaves[2].vector left_child = oblivious_split_info.children[2].vector
right_child = oblivious_split_info.children_leaves[3].vector right_child = oblivious_split_info.children[3].vector
self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001) self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)

View File

@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#include <algorithm>
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include <algorithm>
namespace tensorflow { namespace tensorflow {
namespace boosted_trees { namespace boosted_trees {
namespace trees { namespace trees {
@ -28,14 +28,15 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) { if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
return kInvalidLeaf; return kInvalidLeaf;
} }
// Traverse tree starting at the provided sub-root. // Traverse tree starting at the provided sub-root.
int32 node_id = sub_root_id; int32 node_id = sub_root_id;
// The index of the leave that holds this example in the oblivious case.
int oblivious_leaf_idx = 0;
while (true) { while (true) {
const auto& current_node = config.nodes(node_id); const auto& current_node = config.nodes(node_id);
switch (current_node.node_case()) { switch (current_node.node_case()) {
case TreeNode::kLeaf: { case TreeNode::kLeaf: {
return node_id; return node_id + oblivious_leaf_idx;
} }
case TreeNode::kDenseFloatBinarySplit: { case TreeNode::kDenseFloatBinarySplit: {
const auto& split = current_node.dense_float_binary_split(); const auto& split = current_node.dense_float_binary_split();
@ -100,6 +101,16 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
} }
break; break;
} }
case TreeNode::kObliviousDenseFloatBinarySplit: {
const auto& split = current_node.oblivious_dense_float_binary_split();
oblivious_leaf_idx <<= 1;
if (example.dense_float_features[split.feature_column()] >
split.threshold()) {
oblivious_leaf_idx++;
}
node_id++;
break;
}
case TreeNode::NODE_NOT_SET: { case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString(); LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
break; break;
@ -165,6 +176,11 @@ void DecisionTree::LinkChildren(const std::vector<int32>& children,
split->set_right_id(*++children_it); split->set_right_id(*++children_it);
break; break;
} }
case TreeNode::kObliviousDenseFloatBinarySplit: {
LOG(QFATAL)
<< "Not implemented for the ObliviousDenseFloatBinarySplit case.";
break;
}
case TreeNode::NODE_NOT_SET: { case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "A non-set node cannot have children."; LOG(QFATAL) << "A non-set node cannot have children.";
break; break;
@ -199,6 +215,11 @@ std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
const auto& split = node.categorical_id_set_membership_binary_split(); const auto& split = node.categorical_id_set_membership_binary_split();
return {split.left_id(), split.right_id()}; return {split.left_id(), split.right_id()};
} }
case TreeNode::kObliviousDenseFloatBinarySplit: {
LOG(QFATAL)
<< "Not implemented for the ObliviousDenseFloatBinarySplit case.";
return {};
}
case TreeNode::NODE_NOT_SET: { case TreeNode::NODE_NOT_SET: {
return {}; return {};
} }

View File

@ -57,6 +57,7 @@ REGISTER_OP("GrowTreeEnsemble")
.Input("learning_rate: float") .Input("learning_rate: float")
.Input("dropout_seed: int64") .Input("dropout_seed: int64")
.Input("max_tree_depth: int32") .Input("max_tree_depth: int32")
.Input("weak_learner_type: int32")
.Input("partition_ids: num_handlers * int32") .Input("partition_ids: num_handlers * int32")
.Input("gains: num_handlers * float") .Input("gains: num_handlers * float")
.Input("splits: num_handlers * string") .Input("splits: num_handlers * string")
@ -82,6 +83,7 @@ tree_ensemble_handle: Handle to the ensemble variable.
stamp_token: Stamp token for validating operation consistency. stamp_token: Stamp token for validating operation consistency.
next_stamp_token: Stamp token to be used for the next iteration. next_stamp_token: Stamp token to be used for the next iteration.
learning_rate: Scalar learning rate. learning_rate: Scalar learning rate.
weak_learner_type: The type of weak learner to use.
partition_ids: List of Rank 1 Tensors containing partition Id per candidate. partition_ids: List of Rank 1 Tensors containing partition Id per candidate.
gains: List of Rank 1 Tensors containing gains per candidate. gains: List of Rank 1 Tensors containing gains per candidate.
splits: List of Rank 1 Tensors containing serialized SplitInfo protos per candidate. splits: List of Rank 1 Tensors containing serialized SplitInfo protos per candidate.

View File

@ -19,8 +19,6 @@ message SplitInfo {
} }
message ObliviousSplitInfo { message ObliviousSplitInfo {
// The split node with the feature_column and threshold defined.
tensorflow.boosted_trees.trees.TreeNode split_node = 1; tensorflow.boosted_trees.trees.TreeNode split_node = 1;
// The new leaves of the tree. repeated tensorflow.boosted_trees.trees.Leaf children = 2;
repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2;
} }

View File

@ -15,6 +15,7 @@ message TreeNode {
CategoricalIdBinarySplit categorical_id_binary_split = 5; CategoricalIdBinarySplit categorical_id_binary_split = 5;
CategoricalIdSetMembershipBinarySplit CategoricalIdSetMembershipBinarySplit
categorical_id_set_membership_binary_split = 6; categorical_id_set_membership_binary_split = 6;
ObliviousDenseFloatBinarySplit oblivious_dense_float_binary_split = 7;
} }
TreeNodeMetadata node_metadata = 777; TreeNodeMetadata node_metadata = 777;
} }
@ -26,6 +27,9 @@ message TreeNodeMetadata {
// The original leaf node before this node was split. // The original leaf node before this node was split.
Leaf original_leaf = 2; Leaf original_leaf = 2;
// The original layer of leaves before that layer was converted to a split.
repeated Leaf original_oblivious_leaves = 3;
} }
// Leaves can either hold dense or sparse information. // Leaves can either hold dense or sparse information.
@ -101,6 +105,17 @@ message CategoricalIdSetMembershipBinarySplit {
int32 right_id = 4; int32 right_id = 4;
} }
// Split rule for dense float features in the oblivious case.
message ObliviousDenseFloatBinarySplit {
// Float feature column and split threshold describing
// the rule feature <= threshold.
int32 feature_column = 1;
float threshold = 2;
// We don't store children ids, because either the next node represents the
// whole next layer of the tree or starting with the next node we only have
// leaves.
}
// DecisionTreeConfig describes a list of connected nodes. // DecisionTreeConfig describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf // Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias. // in the case of representing the bias.

View File

@ -91,6 +91,27 @@ def _gen_dense_split_info(fc, threshold, left_weight, right_weight):
return split.SerializeToString() return split.SerializeToString()
def _gen_dense_oblivious_split_info(fc, threshold, leave_weights):
split_str = """
split_node {
oblivious_dense_float_binary_split {
feature_column: %d
threshold: %f
}
}""" % (fc, threshold)
for weight in leave_weights:
split_str += """
children {
vector {
value: %f
}
}""" % (
weight)
split = split_info_pb2.ObliviousSplitInfo()
text_format.Merge(split_str, split)
return split.SerializeToString()
def _gen_categorical_split_info(fc, feat_id, left_weight, right_weight): def _gen_categorical_split_info(fc, feat_id, left_weight, right_weight):
split_str = """ split_str = """
split_node { split_node {
@ -324,7 +345,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect the simpler split from handler 1 to be chosen. # Expect the simpler split from handler 1 to be chosen.
@ -383,6 +405,115 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(stats.attempted_layers, 1) self.assertEqual(stats.attempted_layers, 1)
self.assertProtoEquals(expected_result, tree_ensemble_config) self.assertProtoEquals(expected_result, tree_ensemble_config)
def testGrowEmptyEnsembleObliviousCase(self):
"""Test growing an empty ensemble in the oblivious case."""
with self.test_session() as session:
# Create empty ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0,
tree_ensemble_config=tree_ensemble_config.SerializeToString(),
name="tree_ensemble")
resources.initialize_resources(resources.shared_resources()).run()
# Prepare learner config.
learner_config = _gen_learner_config(
num_classes=2,
l1_reg=0,
l2_reg=0,
tree_complexity=0,
max_depth=1,
min_node_weight=0,
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
# Prepare handler inputs.
# Note that handlers 1 & 3 have the same gain but different splits.
handler1_partitions = np.array([0], dtype=np.int32)
handler1_gains = np.array([7.62], dtype=np.float32)
handler1_split = [
_gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143])
]
handler2_partitions = np.array([0], dtype=np.int32)
handler2_gains = np.array([0.63], dtype=np.float32)
handler2_split = [_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24])]
handler3_partitions = np.array([0], dtype=np.int32)
handler3_gains = np.array([7.62], dtype=np.float32)
handler3_split = [_gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143])]
# Grow tree ensemble.
grow_op = training_ops.grow_tree_ensemble(
tree_ensemble_handle,
stamp_token=0,
next_stamp_token=1,
learning_rate=0.1,
partition_ids=[
handler1_partitions, handler2_partitions, handler3_partitions
],
gains=[handler1_gains, handler2_gains, handler3_gains],
splits=[handler1_split, handler2_split, handler3_split],
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
session.run(grow_op)
# Expect the split with bigger handler_id, i.e. handler 3 to be chosen.
# The grown tree should be finalized as max tree depth is 1.
new_stamp, serialized = session.run(
model_ops.tree_ensemble_serialize(tree_ensemble_handle))
stats = session.run(
training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
tree_ensemble_config.ParseFromString(serialized)
expected_result = """
trees {
nodes {
oblivious_dense_float_binary_split {
feature_column: 0
threshold: 7
}
node_metadata {
gain: 7.62
original_oblivious_leaves {
}
}
}
nodes {
leaf {
vector {
value: -4.375
}
}
}
nodes {
leaf {
vector {
value: 7.143
}
}
}
}
tree_weights: 0.1
tree_metadata {
num_tree_weight_updates: 1
num_layers_grown: 1
is_finalized: true
}
growing_metadata {
num_trees_attempted: 1
num_layers_attempted: 1
}
"""
self.assertEqual(new_stamp, 1)
self.assertEqual(stats.num_trees, 1)
self.assertEqual(stats.num_layers, 1)
self.assertEqual(stats.active_tree, 1)
self.assertEqual(stats.active_layer, 1)
self.assertEqual(stats.attempted_trees, 1)
self.assertEqual(stats.attempted_layers, 1)
self.assertProtoEquals(expected_result, tree_ensemble_config)
def testGrowExistingEnsembleTreeNotFinalized(self): def testGrowExistingEnsembleTreeNotFinalized(self):
"""Test growing an existing ensemble with the last tree not finalized.""" """Test growing an existing ensemble with the last tree not finalized."""
with self.test_session() as session: with self.test_session() as session:
@ -476,7 +607,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect the split for partition 1 to be chosen from handler 1 and # Expect the split for partition 1 to be chosen from handler 1 and
@ -661,7 +793,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect a new tree to be added with the split from handler 1. # Expect a new tree to be added with the split from handler 1.
@ -798,7 +931,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect the ensemble to be empty. # Expect the ensemble to be empty.
@ -869,7 +1003,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect the simpler split from handler 1 to be chosen. # Expect the simpler split from handler 1 to be chosen.
@ -971,7 +1106,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect the split from handler 2 to be chosen despite the negative gain. # Expect the split from handler 2 to be chosen despite the negative gain.
@ -1053,7 +1189,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect the ensemble to be empty as post-pruning will prune # Expect the ensemble to be empty as post-pruning will prune
@ -1120,7 +1257,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect the split from handler 2 to be chosen despite the negative gain. # Expect the split from handler 2 to be chosen despite the negative gain.
@ -1200,7 +1338,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect the negative gain split of partition 1 to be pruned and the # Expect the negative gain split of partition 1 to be pruned and the
@ -1371,7 +1510,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect the split for partition 1 to be chosen from handler 1 and # Expect the split for partition 1 to be chosen from handler 1 and
@ -1470,6 +1610,193 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(stats.attempted_layers, 2) self.assertEqual(stats.attempted_layers, 2)
self.assertProtoEquals(expected_result, tree_ensemble_config) self.assertProtoEquals(expected_result, tree_ensemble_config)
def testGrowEnsembleTreeLayerByLayerObliviousCase(self):
  """Tests layer-by-layer growth of an oblivious tree.

  Starts from an ensemble whose single (unfinalized) tree has one oblivious
  root split, grows one more oblivious layer, and checks that the split with
  the highest gain is applied to every leaf of the previous layer, with the
  per-leaf weight deltas added onto the original leaf values.
  """
  with self.test_session() as session:
    # Create existing ensemble with one oblivious root split and two leaves.
    tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
    text_format.Merge(
        """
        trees {
          nodes {
            oblivious_dense_float_binary_split {
              feature_column: 4
              threshold: 7
            }
            node_metadata {
              gain: 7.62
              original_oblivious_leaves {
              }
            }
          }
          nodes {
            leaf {
              vector {
                value: 7.143
              }
            }
          }
          nodes {
            leaf {
              vector {
                value: -4.375
              }
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          num_tree_weight_updates: 1
          num_layers_grown: 1
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
    """, tree_ensemble_config)
    tree_ensemble_handle = model_ops.tree_ensemble_variable(
        stamp_token=0,
        tree_ensemble_config=tree_ensemble_config.SerializeToString(),
        name="tree_ensemble")
    resources.initialize_resources(resources.shared_resources()).run()

    # Prepare learner config. max_depth=3 so the tree is not finalized after
    # this (second) layer is grown.
    learner_config = _gen_learner_config(
        num_classes=2,
        l1_reg=0,
        l2_reg=0,
        tree_complexity=0,
        max_depth=3,
        min_node_weight=0,
        pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
        growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER)

    # Prepare handler inputs. All handlers propose a split for the single
    # partition 0; each oblivious split carries four leaf weight deltas
    # (left/right child for each of the two existing leaves).
    handler1_partitions = np.array([0], dtype=np.int32)
    handler1_gains = np.array([1.4], dtype=np.float32)
    handler1_split = [
        _gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5])
    ]
    handler2_partitions = np.array([0], dtype=np.int32)
    handler2_gains = np.array([2.7], dtype=np.float32)
    handler2_split = [
        _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4]),
    ]
    handler3_partitions = np.array([0], dtype=np.int32)
    handler3_gains = np.array([1.7], dtype=np.float32)
    handler3_split = [
        _gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1])
    ]

    # Grow tree ensemble layer by layer with the oblivious weak learner.
    grow_op = training_ops.grow_tree_ensemble(
        tree_ensemble_handle,
        stamp_token=0,
        next_stamp_token=1,
        learning_rate=0.1,
        partition_ids=[
            handler1_partitions, handler2_partitions, handler3_partitions
        ],
        gains=[handler1_gains, handler2_gains, handler3_gains],
        splits=[handler1_split, handler2_split, handler3_split],
        learner_config=learner_config.SerializeToString(),
        dropout_seed=123,
        center_bias=True,
        max_tree_depth=learner_config.constraints.max_tree_depth,
        weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
    session.run(grow_op)

    # Expect handler 2's split to be chosen for the whole layer, since its
    # gain (2.7) is the highest of the three handlers.
    # The grown tree should not be finalized as max tree depth is 3 and
    # it's only grown 2 layers.
    # Each original leaf is replaced by two children; handler 2's deltas
    # [-0.6, 0.24, 0.3, 0.4] are added to the original leaf weights:
    # 7.143 -> 6.543 / 7.383 and -4.375 -> -4.075 / -3.975. The original
    # leaf values are preserved in original_oblivious_leaves metadata.
    new_stamp, serialized = session.run(
        model_ops.tree_ensemble_serialize(tree_ensemble_handle))
    stats = session.run(
        training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
    tree_ensemble_config.ParseFromString(serialized)
    expected_result = """
      trees {
        nodes {
          oblivious_dense_float_binary_split {
            feature_column: 4
            threshold: 7
          }
          node_metadata {
            gain: 7.62
            original_oblivious_leaves {
            }
          }
        }
        nodes {
          oblivious_dense_float_binary_split {
            feature_column: 0
            threshold: 0.23
          }
          node_metadata {
            gain: 2.7
            original_oblivious_leaves {
              vector {
                value: 7.143
              }
            }
            original_oblivious_leaves {
              vector {
                value: -4.375
              }
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: 6.543
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: 7.383
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: -4.075
            }
          }
        }
        nodes {
          leaf {
            vector {
              value: -3.975
            }
          }
        }
      }
      tree_weights: 0.1
      tree_metadata {
        num_tree_weight_updates: 1
        num_layers_grown: 2
      }
      growing_metadata {
        num_trees_attempted: 1
        num_layers_attempted: 2
      }
    """
    self.assertEqual(new_stamp, 1)
    self.assertEqual(stats.num_trees, 0)
    self.assertEqual(stats.num_layers, 2)
    self.assertEqual(stats.active_tree, 1)
    self.assertEqual(stats.active_layer, 2)
    self.assertEqual(stats.attempted_trees, 1)
    self.assertEqual(stats.attempted_layers, 2)
    self.assertProtoEquals(expected_result, tree_ensemble_config)
def testGrowExistingEnsembleTreeFinalizedWithDropout(self): def testGrowExistingEnsembleTreeFinalizedWithDropout(self):
"""Test growing an existing ensemble with the last tree finalized.""" """Test growing an existing ensemble with the last tree finalized."""
with self.test_session() as session: with self.test_session() as session:
@ -1575,7 +1902,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
# Expect a new tree to be added with the split from handler 1. # Expect a new tree to be added with the split from handler 1.
@ -1700,7 +2028,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(), learner_config=learner_config.SerializeToString(),
dropout_seed=123, dropout_seed=123,
center_bias=True, center_bias=True,
max_tree_depth=learner_config.constraints.max_tree_depth) max_tree_depth=learner_config.constraints.max_tree_depth,
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op) session.run(grow_op)
_, serialized = session.run( _, serialized = session.run(

View File

@ -1076,7 +1076,8 @@ class GradientBoostedDecisionTreeModel(object):
learner_config=self._learner_config_serialized, learner_config=self._learner_config_serialized,
dropout_seed=dropout_seed, dropout_seed=dropout_seed,
center_bias=self._center_bias, center_bias=self._center_bias,
max_tree_depth=self._max_tree_depth) max_tree_depth=self._max_tree_depth,
weak_learner_type=self._learner_config.weak_learner_type)
def _grow_ensemble_not_ready_fn(): def _grow_ensemble_not_ready_fn():
# Don't grow the ensemble, just update the stamp. # Don't grow the ensemble, just update the stamp.
@ -1091,7 +1092,8 @@ class GradientBoostedDecisionTreeModel(object):
learner_config=self._learner_config_serialized, learner_config=self._learner_config_serialized,
dropout_seed=dropout_seed, dropout_seed=dropout_seed,
center_bias=self._center_bias, center_bias=self._center_bias,
max_tree_depth=self._max_tree_depth) max_tree_depth=self._max_tree_depth,
weak_learner_type=self._learner_config.weak_learner_type)
def _grow_ensemble_fn(): def _grow_ensemble_fn():
# Conditionally grow an ensemble depending on whether the splits # Conditionally grow an ensemble depending on whether the splits