1) Update the proto files for oblivious trees.
2) Grow a new layer of an oblivious tree. PiperOrigin-RevId: 209633300
This commit is contained in:
parent
e787c15ae8
commit
e28f9da84b
@ -445,6 +445,7 @@ tf_kernel_library(
|
|||||||
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
|
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
|
||||||
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
|
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
|
||||||
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
|
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
|
||||||
|
"//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
|
||||||
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
|
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
|
||||||
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
|
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
|
||||||
"//tensorflow/core:framework_headers_lib",
|
"//tensorflow/core:framework_headers_lib",
|
||||||
|
@ -383,19 +383,20 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
|
|||||||
best_gain -= num_elements * state->tree_complexity_regularization();
|
best_gain -= num_elements * state->tree_complexity_regularization();
|
||||||
|
|
||||||
ObliviousSplitInfo oblivious_split_info;
|
ObliviousSplitInfo oblivious_split_info;
|
||||||
auto* oblivious_dense_split = oblivious_split_info.mutable_split_node()
|
auto* oblivious_dense_split =
|
||||||
->mutable_dense_float_binary_split();
|
oblivious_split_info.mutable_split_node()
|
||||||
|
->mutable_oblivious_dense_float_binary_split();
|
||||||
oblivious_dense_split->set_feature_column(state->feature_column_group_id());
|
oblivious_dense_split->set_feature_column(state->feature_column_group_id());
|
||||||
oblivious_dense_split->set_threshold(
|
oblivious_dense_split->set_threshold(
|
||||||
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
|
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
|
||||||
(*gains)(0) = best_gain;
|
(*gains)(0) = best_gain;
|
||||||
|
|
||||||
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
|
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
|
||||||
auto* left_children = oblivious_split_info.add_children_leaves();
|
auto* left_child = oblivious_split_info.add_children();
|
||||||
auto* right_children = oblivious_split_info.add_children_leaves();
|
auto* right_child = oblivious_split_info.add_children();
|
||||||
|
|
||||||
state->FillLeaf(best_left_node_stats[root_idx], left_children);
|
state->FillLeaf(best_left_node_stats[root_idx], left_child);
|
||||||
state->FillLeaf(best_right_node_stats[root_idx], right_children);
|
state->FillLeaf(best_right_node_stats[root_idx], right_child);
|
||||||
|
|
||||||
const int start_index = partition_boundaries[root_idx];
|
const int start_index = partition_boundaries[root_idx];
|
||||||
(*output_partition_ids)(root_idx) = partition_ids(start_index);
|
(*output_partition_ids)(root_idx) = partition_ids(start_index);
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
|
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
|
||||||
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
|
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
|
||||||
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
|
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
|
||||||
|
#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
|
||||||
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
|
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
@ -26,6 +27,7 @@ namespace boosted_trees {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using boosted_trees::learner::LearnerConfig;
|
||||||
using boosted_trees::learner::LearningRateConfig;
|
using boosted_trees::learner::LearningRateConfig;
|
||||||
using boosted_trees::trees::Leaf;
|
using boosted_trees::trees::Leaf;
|
||||||
using boosted_trees::trees::TreeNode;
|
using boosted_trees::trees::TreeNode;
|
||||||
@ -42,6 +44,9 @@ struct SplitCandidate {
|
|||||||
|
|
||||||
// Split info.
|
// Split info.
|
||||||
learner::SplitInfo split_info;
|
learner::SplitInfo split_info;
|
||||||
|
|
||||||
|
// Oblivious split info.
|
||||||
|
learner::ObliviousSplitInfo oblivious_split_info;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Checks that the leaf is not empty.
|
// Checks that the leaf is not empty.
|
||||||
@ -343,7 +348,12 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
|
OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
|
||||||
float learning_rate = learning_rate_t->scalar<float>()();
|
float learning_rate = learning_rate_t->scalar<float>()();
|
||||||
|
|
||||||
// Read seed that was used for dropout.
|
// Read the weak learner type to use.
|
||||||
|
const Tensor* weak_learner_type_t;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->input("weak_learner_type", &weak_learner_type_t));
|
||||||
|
const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
|
||||||
|
|
||||||
const Tensor* seed_t;
|
const Tensor* seed_t;
|
||||||
OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t));
|
OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t));
|
||||||
// Cast seed to uint64.
|
// Cast seed to uint64.
|
||||||
@ -363,9 +373,18 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
|
|
||||||
// Find best splits for each active partition.
|
// Find best splits for each active partition.
|
||||||
std::map<int32, SplitCandidate> best_splits;
|
std::map<int32, SplitCandidate> best_splits;
|
||||||
FindBestSplitsPerPartition(context, partition_ids_list, gains_list,
|
switch (weak_learner_type) {
|
||||||
splits_list, &best_splits);
|
case LearnerConfig::NORMAL_DECISION_TREE: {
|
||||||
|
FindBestSplitsPerPartitionNormal(context, partition_ids_list,
|
||||||
|
gains_list, splits_list, &best_splits);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
|
||||||
|
FindBestSplitsPerPartitionOblivious(context, gains_list, splits_list,
|
||||||
|
&best_splits);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
// No-op if no new splits can be considered.
|
// No-op if no new splits can be considered.
|
||||||
if (best_splits.empty()) {
|
if (best_splits.empty()) {
|
||||||
LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
|
LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
|
||||||
@ -377,25 +396,34 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->input("max_tree_depth", &max_tree_depth_t));
|
context->input("max_tree_depth", &max_tree_depth_t));
|
||||||
const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()();
|
const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()();
|
||||||
|
|
||||||
// Update and retrieve the growable tree.
|
// Update and retrieve the growable tree.
|
||||||
// If the tree is fully built and dropout was applied, it also adjusts the
|
// If the tree is fully built and dropout was applied, it also adjusts the
|
||||||
// weights of dropped and the last tree.
|
// weights of dropped and the last tree.
|
||||||
boosted_trees::trees::DecisionTreeConfig* const tree_config =
|
boosted_trees::trees::DecisionTreeConfig* const tree_config =
|
||||||
UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
|
UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
|
||||||
dropout_seed, max_tree_depth);
|
dropout_seed, max_tree_depth,
|
||||||
|
weak_learner_type);
|
||||||
// Split tree nodes.
|
// Split tree nodes.
|
||||||
for (auto& split_entry : best_splits) {
|
switch (weak_learner_type) {
|
||||||
SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
|
case LearnerConfig::NORMAL_DECISION_TREE: {
|
||||||
ensemble_resource);
|
for (auto& split_entry : best_splits) {
|
||||||
|
SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
|
||||||
|
ensemble_resource);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
|
||||||
|
SplitTreeLayer(&best_splits[0], tree_config, ensemble_resource);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Post-prune finalized tree if needed.
|
// Post-prune finalized tree if needed.
|
||||||
if (learner_config_.pruning_mode() ==
|
if (learner_config_.pruning_mode() ==
|
||||||
boosted_trees::learner::LearnerConfig::POST_PRUNE &&
|
boosted_trees::learner::LearnerConfig::POST_PRUNE &&
|
||||||
ensemble_resource->LastTreeMetadata()->is_finalized()) {
|
ensemble_resource->LastTreeMetadata()->is_finalized()) {
|
||||||
VLOG(2) << "Post-pruning finalized tree.";
|
VLOG(2) << "Post-pruning finalized tree.";
|
||||||
|
if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE) {
|
||||||
|
LOG(FATAL) << "Post-prunning is not implemented for Oblivious trees.";
|
||||||
|
}
|
||||||
PruneTree(tree_config);
|
PruneTree(tree_config);
|
||||||
|
|
||||||
// If after post-pruning the whole tree has no gain, remove the tree
|
// If after post-pruning the whole tree has no gain, remove the tree
|
||||||
@ -409,10 +437,9 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
private:
|
private:
|
||||||
// Helper method which effectively does a reduce over all split candidates
|
// Helper method which effectively does a reduce over all split candidates
|
||||||
// and finds the best split for each partition.
|
// and finds the best split for each partition.
|
||||||
void FindBestSplitsPerPartition(
|
void FindBestSplitsPerPartitionNormal(
|
||||||
OpKernelContext* const context,
|
OpKernelContext* const context, const OpInputList& partition_ids_list,
|
||||||
const OpInputList& partition_ids_list, const OpInputList& gains_list,
|
const OpInputList& gains_list, const OpInputList& splits_list,
|
||||||
const OpInputList& splits_list,
|
|
||||||
std::map<int32, SplitCandidate>* best_splits) {
|
std::map<int32, SplitCandidate>* best_splits) {
|
||||||
// Find best split per partition going through every feature candidate.
|
// Find best split per partition going through every feature candidate.
|
||||||
// TODO(salehay): Is this worth parallelizing?
|
// TODO(salehay): Is this worth parallelizing?
|
||||||
@ -446,6 +473,90 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FindBestSplitsPerPartitionOblivious(
|
||||||
|
OpKernelContext* const context, const OpInputList& gains_list,
|
||||||
|
const OpInputList& splits_list,
|
||||||
|
std::map<int32, SplitCandidate>* best_splits) {
|
||||||
|
// Find best split per partition going through every feature candidate.
|
||||||
|
for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
|
||||||
|
const auto& gains = gains_list[handler_id].vec<float>();
|
||||||
|
const auto& splits = splits_list[handler_id].vec<string>();
|
||||||
|
OP_REQUIRES(context, gains.size() == 1,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Gains size must be one for oblivious weak learner: ",
|
||||||
|
gains.size(), " != ", 1));
|
||||||
|
OP_REQUIRES(context, splits.size() == 1,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Splits size must be one for oblivious weak learner: ",
|
||||||
|
splits.size(), " != ", 1));
|
||||||
|
// Get current split candidate.
|
||||||
|
const auto& gain = gains(0);
|
||||||
|
const auto& serialized_split = splits(0);
|
||||||
|
SplitCandidate split;
|
||||||
|
split.handler_id = handler_id;
|
||||||
|
split.gain = gain;
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, split.oblivious_split_info.ParseFromString(serialized_split),
|
||||||
|
errors::InvalidArgument("Unable to parse oblivious split info."));
|
||||||
|
|
||||||
|
auto split_info = split.oblivious_split_info;
|
||||||
|
CHECK(split_info.children_size() % 2 == 0)
|
||||||
|
<< "The oblivious split should generate an even number of children: "
|
||||||
|
<< split_info.children_size();
|
||||||
|
|
||||||
|
// If every node is pure, then we shouldn't split.
|
||||||
|
bool only_pure_nodes = true;
|
||||||
|
for (int idx = 0; idx < split_info.children_size(); idx += 2) {
|
||||||
|
if (IsLeafWellFormed(*split_info.mutable_children(idx)) &&
|
||||||
|
IsLeafWellFormed(*split_info.mutable_children(idx + 1))) {
|
||||||
|
only_pure_nodes = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (only_pure_nodes) {
|
||||||
|
VLOG(1) << "The oblivious split does not actually split anything.";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't consider negative splits if we're pre-pruning the tree.
|
||||||
|
if (learner_config_.pruning_mode() == learner::LearnerConfig::PRE_PRUNE &&
|
||||||
|
gain < 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take the split if we don't have a candidate yet.
|
||||||
|
auto best_split_it = best_splits->find(0);
|
||||||
|
if (best_split_it == best_splits->end()) {
|
||||||
|
best_splits->insert(std::make_pair(0, std::move(split)));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine if we should update best split.
|
||||||
|
SplitCandidate& best_split = best_split_it->second;
|
||||||
|
trees::TreeNode current_node = split_info.split_node();
|
||||||
|
trees::TreeNode best_node = best_split.oblivious_split_info.split_node();
|
||||||
|
if (TF_PREDICT_FALSE(gain == best_split.gain)) {
|
||||||
|
// Tie break on node case preferring simpler tree node types.
|
||||||
|
VLOG(2) << "Attempting to tie break with smaller node case. "
|
||||||
|
<< "(current split: " << current_node.node_case()
|
||||||
|
<< ", best split: " << best_node.node_case() << ")";
|
||||||
|
if (current_node.node_case() < best_node.node_case()) {
|
||||||
|
best_split = std::move(split);
|
||||||
|
} else if (current_node.node_case() == best_node.node_case()) {
|
||||||
|
// Tie break on handler Id.
|
||||||
|
VLOG(2) << "Tie breaking with higher handler Id. "
|
||||||
|
<< "(current split: " << handler_id
|
||||||
|
<< ", best split: " << best_split.handler_id << ")";
|
||||||
|
if (handler_id > best_split.handler_id) {
|
||||||
|
best_split = std::move(split);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (gain > best_split.gain) {
|
||||||
|
best_split = std::move(split);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void UpdateTreeWeightsIfDropout(
|
void UpdateTreeWeightsIfDropout(
|
||||||
boosted_trees::models::DecisionTreeEnsembleResource* const
|
boosted_trees::models::DecisionTreeEnsembleResource* const
|
||||||
ensemble_resource,
|
ensemble_resource,
|
||||||
@ -501,7 +612,7 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
boosted_trees::models::DecisionTreeEnsembleResource* const
|
boosted_trees::models::DecisionTreeEnsembleResource* const
|
||||||
ensemble_resource,
|
ensemble_resource,
|
||||||
const float learning_rate, const uint64 dropout_seed,
|
const float learning_rate, const uint64 dropout_seed,
|
||||||
const int32 max_tree_depth) {
|
const int32 max_tree_depth, const int32 weak_learner_type) {
|
||||||
const auto num_trees = ensemble_resource->num_trees();
|
const auto num_trees = ensemble_resource->num_trees();
|
||||||
if (num_trees <= 0 ||
|
if (num_trees <= 0 ||
|
||||||
ensemble_resource->LastTreeMetadata()->is_finalized()) {
|
ensemble_resource->LastTreeMetadata()->is_finalized()) {
|
||||||
@ -647,6 +758,60 @@ class GrowTreeEnsembleOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SplitTreeLayer(
|
||||||
|
SplitCandidate* split,
|
||||||
|
boosted_trees::trees::DecisionTreeConfig* tree_config,
|
||||||
|
boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
|
||||||
|
int depth = 0;
|
||||||
|
while (depth < tree_config->nodes_size() &&
|
||||||
|
tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
|
||||||
|
depth++;
|
||||||
|
}
|
||||||
|
CHECK(tree_config->nodes_size() > 0)
|
||||||
|
<< "A tree must have at least one dummy leaf.";
|
||||||
|
// The number of new children.
|
||||||
|
int num_children = 1 << (depth + 1);
|
||||||
|
auto split_info = split->oblivious_split_info;
|
||||||
|
CHECK(num_children == split_info.children_size())
|
||||||
|
<< "Wrong number of new children: " << num_children
|
||||||
|
<< " != " << split_info.children_size();
|
||||||
|
for (int idx = 0; idx < num_children; idx += 2) {
|
||||||
|
// Old leaf is at position depth + idx / 2.
|
||||||
|
trees::Leaf old_leaf =
|
||||||
|
*tree_config->mutable_nodes(depth + idx / 2)->mutable_leaf();
|
||||||
|
// Update left leaf.
|
||||||
|
*split_info.mutable_children(idx) =
|
||||||
|
*MergeLeafWeights(old_leaf, split_info.mutable_children(idx));
|
||||||
|
// Update right leaf.
|
||||||
|
*split_info.mutable_children(idx + 1) =
|
||||||
|
*MergeLeafWeights(old_leaf, split_info.mutable_children(idx + 1));
|
||||||
|
}
|
||||||
|
TreeNodeMetadata* split_metadata =
|
||||||
|
split_info.mutable_split_node()->mutable_node_metadata();
|
||||||
|
split_metadata->set_gain(split->gain);
|
||||||
|
|
||||||
|
TreeNode new_split = *split_info.mutable_split_node();
|
||||||
|
// Move old children to metadata.
|
||||||
|
for (int idx = depth; idx < tree_config->nodes_size(); idx++) {
|
||||||
|
*new_split.mutable_node_metadata()->add_original_oblivious_leaves() =
|
||||||
|
*tree_config->mutable_nodes(idx)->mutable_leaf();
|
||||||
|
}
|
||||||
|
// Add the new split to the tree_config in place before the children start.
|
||||||
|
*tree_config->mutable_nodes(depth) = new_split;
|
||||||
|
// Add the new children
|
||||||
|
int nodes_size = tree_config->nodes_size();
|
||||||
|
for (int idx = 0; idx < num_children; idx++) {
|
||||||
|
if (idx + depth + 1 < nodes_size) {
|
||||||
|
// Update leaves that were already there.
|
||||||
|
*tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() =
|
||||||
|
*split_info.mutable_children(idx);
|
||||||
|
} else {
|
||||||
|
// Add new leaves.
|
||||||
|
*tree_config->add_nodes()->mutable_leaf() =
|
||||||
|
*split_info.mutable_children(idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
|
void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
|
||||||
// No-op if tree is empty.
|
// No-op if tree is empty.
|
||||||
if (tree_config->nodes_size() <= 0) {
|
if (tree_config->nodes_size() <= 0) {
|
||||||
|
@ -258,8 +258,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
|
oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
|
||||||
oblivious_split_info.ParseFromString(splits[0])
|
oblivious_split_info.ParseFromString(splits[0])
|
||||||
split_node = oblivious_split_info.split_node.dense_float_binary_split
|
split_node = oblivious_split_info.split_node
|
||||||
|
split_node = split_node.oblivious_dense_float_binary_split
|
||||||
self.assertAllClose(0.3, split_node.threshold, 0.00001)
|
self.assertAllClose(0.3, split_node.threshold, 0.00001)
|
||||||
self.assertEqual(0, split_node.feature_column)
|
self.assertEqual(0, split_node.feature_column)
|
||||||
|
|
||||||
@ -279,8 +279,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
|
|||||||
# (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
|
# (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
|
||||||
expected_bias_gain_0 = 0.46043165467625896
|
expected_bias_gain_0 = 0.46043165467625896
|
||||||
|
|
||||||
left_child = oblivious_split_info.children_leaves[0].vector
|
left_child = oblivious_split_info.children[0].vector
|
||||||
right_child = oblivious_split_info.children_leaves[1].vector
|
right_child = oblivious_split_info.children[1].vector
|
||||||
|
|
||||||
self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
|
self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
|
||||||
|
|
||||||
@ -296,8 +296,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
|
|||||||
# (-4 + 0.1) ** 2 / (0.13 + 1)
|
# (-4 + 0.1) ** 2 / (0.13 + 1)
|
||||||
expected_bias_gain_1 = 13.460176991150442
|
expected_bias_gain_1 = 13.460176991150442
|
||||||
|
|
||||||
left_child = oblivious_split_info.children_leaves[2].vector
|
left_child = oblivious_split_info.children[2].vector
|
||||||
right_child = oblivious_split_info.children_leaves[3].vector
|
right_child = oblivious_split_info.children[3].vector
|
||||||
|
|
||||||
self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
|
self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
|
||||||
|
|
||||||
|
@ -12,11 +12,11 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
|
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace boosted_trees {
|
namespace boosted_trees {
|
||||||
namespace trees {
|
namespace trees {
|
||||||
@ -28,14 +28,15 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
|
|||||||
if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
|
if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
|
||||||
return kInvalidLeaf;
|
return kInvalidLeaf;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Traverse tree starting at the provided sub-root.
|
// Traverse tree starting at the provided sub-root.
|
||||||
int32 node_id = sub_root_id;
|
int32 node_id = sub_root_id;
|
||||||
|
// The index of the leave that holds this example in the oblivious case.
|
||||||
|
int oblivious_leaf_idx = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
const auto& current_node = config.nodes(node_id);
|
const auto& current_node = config.nodes(node_id);
|
||||||
switch (current_node.node_case()) {
|
switch (current_node.node_case()) {
|
||||||
case TreeNode::kLeaf: {
|
case TreeNode::kLeaf: {
|
||||||
return node_id;
|
return node_id + oblivious_leaf_idx;
|
||||||
}
|
}
|
||||||
case TreeNode::kDenseFloatBinarySplit: {
|
case TreeNode::kDenseFloatBinarySplit: {
|
||||||
const auto& split = current_node.dense_float_binary_split();
|
const auto& split = current_node.dense_float_binary_split();
|
||||||
@ -100,6 +101,16 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case TreeNode::kObliviousDenseFloatBinarySplit: {
|
||||||
|
const auto& split = current_node.oblivious_dense_float_binary_split();
|
||||||
|
oblivious_leaf_idx <<= 1;
|
||||||
|
if (example.dense_float_features[split.feature_column()] >
|
||||||
|
split.threshold()) {
|
||||||
|
oblivious_leaf_idx++;
|
||||||
|
}
|
||||||
|
node_id++;
|
||||||
|
break;
|
||||||
|
}
|
||||||
case TreeNode::NODE_NOT_SET: {
|
case TreeNode::NODE_NOT_SET: {
|
||||||
LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
|
LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
|
||||||
break;
|
break;
|
||||||
@ -165,6 +176,11 @@ void DecisionTree::LinkChildren(const std::vector<int32>& children,
|
|||||||
split->set_right_id(*++children_it);
|
split->set_right_id(*++children_it);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case TreeNode::kObliviousDenseFloatBinarySplit: {
|
||||||
|
LOG(QFATAL)
|
||||||
|
<< "Not implemented for the ObliviousDenseFloatBinarySplit case.";
|
||||||
|
break;
|
||||||
|
}
|
||||||
case TreeNode::NODE_NOT_SET: {
|
case TreeNode::NODE_NOT_SET: {
|
||||||
LOG(QFATAL) << "A non-set node cannot have children.";
|
LOG(QFATAL) << "A non-set node cannot have children.";
|
||||||
break;
|
break;
|
||||||
@ -199,6 +215,11 @@ std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
|
|||||||
const auto& split = node.categorical_id_set_membership_binary_split();
|
const auto& split = node.categorical_id_set_membership_binary_split();
|
||||||
return {split.left_id(), split.right_id()};
|
return {split.left_id(), split.right_id()};
|
||||||
}
|
}
|
||||||
|
case TreeNode::kObliviousDenseFloatBinarySplit: {
|
||||||
|
LOG(QFATAL)
|
||||||
|
<< "Not implemented for the ObliviousDenseFloatBinarySplit case.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
case TreeNode::NODE_NOT_SET: {
|
case TreeNode::NODE_NOT_SET: {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
@ -57,6 +57,7 @@ REGISTER_OP("GrowTreeEnsemble")
|
|||||||
.Input("learning_rate: float")
|
.Input("learning_rate: float")
|
||||||
.Input("dropout_seed: int64")
|
.Input("dropout_seed: int64")
|
||||||
.Input("max_tree_depth: int32")
|
.Input("max_tree_depth: int32")
|
||||||
|
.Input("weak_learner_type: int32")
|
||||||
.Input("partition_ids: num_handlers * int32")
|
.Input("partition_ids: num_handlers * int32")
|
||||||
.Input("gains: num_handlers * float")
|
.Input("gains: num_handlers * float")
|
||||||
.Input("splits: num_handlers * string")
|
.Input("splits: num_handlers * string")
|
||||||
@ -82,6 +83,7 @@ tree_ensemble_handle: Handle to the ensemble variable.
|
|||||||
stamp_token: Stamp token for validating operation consistency.
|
stamp_token: Stamp token for validating operation consistency.
|
||||||
next_stamp_token: Stamp token to be used for the next iteration.
|
next_stamp_token: Stamp token to be used for the next iteration.
|
||||||
learning_rate: Scalar learning rate.
|
learning_rate: Scalar learning rate.
|
||||||
|
weak_learner_type: The type of weak learner to use.
|
||||||
partition_ids: List of Rank 1 Tensors containing partition Id per candidate.
|
partition_ids: List of Rank 1 Tensors containing partition Id per candidate.
|
||||||
gains: List of Rank 1 Tensors containing gains per candidate.
|
gains: List of Rank 1 Tensors containing gains per candidate.
|
||||||
splits: List of Rank 1 Tensors containing serialized SplitInfo protos per candidate.
|
splits: List of Rank 1 Tensors containing serialized SplitInfo protos per candidate.
|
||||||
|
@ -19,8 +19,6 @@ message SplitInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
message ObliviousSplitInfo {
|
message ObliviousSplitInfo {
|
||||||
// The split node with the feature_column and threshold defined.
|
|
||||||
tensorflow.boosted_trees.trees.TreeNode split_node = 1;
|
tensorflow.boosted_trees.trees.TreeNode split_node = 1;
|
||||||
// The new leaves of the tree.
|
repeated tensorflow.boosted_trees.trees.Leaf children = 2;
|
||||||
repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2;
|
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,7 @@ message TreeNode {
|
|||||||
CategoricalIdBinarySplit categorical_id_binary_split = 5;
|
CategoricalIdBinarySplit categorical_id_binary_split = 5;
|
||||||
CategoricalIdSetMembershipBinarySplit
|
CategoricalIdSetMembershipBinarySplit
|
||||||
categorical_id_set_membership_binary_split = 6;
|
categorical_id_set_membership_binary_split = 6;
|
||||||
|
ObliviousDenseFloatBinarySplit oblivious_dense_float_binary_split = 7;
|
||||||
}
|
}
|
||||||
TreeNodeMetadata node_metadata = 777;
|
TreeNodeMetadata node_metadata = 777;
|
||||||
}
|
}
|
||||||
@ -26,6 +27,9 @@ message TreeNodeMetadata {
|
|||||||
|
|
||||||
// The original leaf node before this node was split.
|
// The original leaf node before this node was split.
|
||||||
Leaf original_leaf = 2;
|
Leaf original_leaf = 2;
|
||||||
|
|
||||||
|
// The original layer of leaves before that layer was converted to a split.
|
||||||
|
repeated Leaf original_oblivious_leaves = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Leaves can either hold dense or sparse information.
|
// Leaves can either hold dense or sparse information.
|
||||||
@ -101,6 +105,17 @@ message CategoricalIdSetMembershipBinarySplit {
|
|||||||
int32 right_id = 4;
|
int32 right_id = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Split rule for dense float features in the oblivious case.
|
||||||
|
message ObliviousDenseFloatBinarySplit {
|
||||||
|
// Float feature column and split threshold describing
|
||||||
|
// the rule feature <= threshold.
|
||||||
|
int32 feature_column = 1;
|
||||||
|
float threshold = 2;
|
||||||
|
// We don't store children ids, because either the next node represents the
|
||||||
|
// whole next layer of the tree or starting with the next node we only have
|
||||||
|
// leaves.
|
||||||
|
}
|
||||||
|
|
||||||
// DecisionTreeConfig describes a list of connected nodes.
|
// DecisionTreeConfig describes a list of connected nodes.
|
||||||
// Node 0 must be the root and can carry any payload including a leaf
|
// Node 0 must be the root and can carry any payload including a leaf
|
||||||
// in the case of representing the bias.
|
// in the case of representing the bias.
|
||||||
|
@ -91,6 +91,27 @@ def _gen_dense_split_info(fc, threshold, left_weight, right_weight):
|
|||||||
return split.SerializeToString()
|
return split.SerializeToString()
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_dense_oblivious_split_info(fc, threshold, leave_weights):
|
||||||
|
split_str = """
|
||||||
|
split_node {
|
||||||
|
oblivious_dense_float_binary_split {
|
||||||
|
feature_column: %d
|
||||||
|
threshold: %f
|
||||||
|
}
|
||||||
|
}""" % (fc, threshold)
|
||||||
|
for weight in leave_weights:
|
||||||
|
split_str += """
|
||||||
|
children {
|
||||||
|
vector {
|
||||||
|
value: %f
|
||||||
|
}
|
||||||
|
}""" % (
|
||||||
|
weight)
|
||||||
|
split = split_info_pb2.ObliviousSplitInfo()
|
||||||
|
text_format.Merge(split_str, split)
|
||||||
|
return split.SerializeToString()
|
||||||
|
|
||||||
|
|
||||||
def _gen_categorical_split_info(fc, feat_id, left_weight, right_weight):
|
def _gen_categorical_split_info(fc, feat_id, left_weight, right_weight):
|
||||||
split_str = """
|
split_str = """
|
||||||
split_node {
|
split_node {
|
||||||
@ -324,7 +345,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect the simpler split from handler 1 to be chosen.
|
# Expect the simpler split from handler 1 to be chosen.
|
||||||
@ -383,6 +405,115 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(stats.attempted_layers, 1)
|
self.assertEqual(stats.attempted_layers, 1)
|
||||||
self.assertProtoEquals(expected_result, tree_ensemble_config)
|
self.assertProtoEquals(expected_result, tree_ensemble_config)
|
||||||
|
|
||||||
|
def testGrowEmptyEnsembleObliviousCase(self):
|
||||||
|
"""Test growing an empty ensemble in the oblivious case."""
|
||||||
|
with self.test_session() as session:
|
||||||
|
# Create empty ensemble.
|
||||||
|
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
|
||||||
|
tree_ensemble_handle = model_ops.tree_ensemble_variable(
|
||||||
|
stamp_token=0,
|
||||||
|
tree_ensemble_config=tree_ensemble_config.SerializeToString(),
|
||||||
|
name="tree_ensemble")
|
||||||
|
resources.initialize_resources(resources.shared_resources()).run()
|
||||||
|
|
||||||
|
# Prepare learner config.
|
||||||
|
learner_config = _gen_learner_config(
|
||||||
|
num_classes=2,
|
||||||
|
l1_reg=0,
|
||||||
|
l2_reg=0,
|
||||||
|
tree_complexity=0,
|
||||||
|
max_depth=1,
|
||||||
|
min_node_weight=0,
|
||||||
|
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
|
||||||
|
growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
|
||||||
|
|
||||||
|
# Prepare handler inputs.
|
||||||
|
# Note that handlers 1 & 3 have the same gain but different splits.
|
||||||
|
handler1_partitions = np.array([0], dtype=np.int32)
|
||||||
|
handler1_gains = np.array([7.62], dtype=np.float32)
|
||||||
|
handler1_split = [
|
||||||
|
_gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143])
|
||||||
|
]
|
||||||
|
handler2_partitions = np.array([0], dtype=np.int32)
|
||||||
|
handler2_gains = np.array([0.63], dtype=np.float32)
|
||||||
|
handler2_split = [_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24])]
|
||||||
|
handler3_partitions = np.array([0], dtype=np.int32)
|
||||||
|
handler3_gains = np.array([7.62], dtype=np.float32)
|
||||||
|
handler3_split = [_gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143])]
|
||||||
|
|
||||||
|
# Grow tree ensemble.
|
||||||
|
grow_op = training_ops.grow_tree_ensemble(
|
||||||
|
tree_ensemble_handle,
|
||||||
|
stamp_token=0,
|
||||||
|
next_stamp_token=1,
|
||||||
|
learning_rate=0.1,
|
||||||
|
partition_ids=[
|
||||||
|
handler1_partitions, handler2_partitions, handler3_partitions
|
||||||
|
],
|
||||||
|
gains=[handler1_gains, handler2_gains, handler3_gains],
|
||||||
|
splits=[handler1_split, handler2_split, handler3_split],
|
||||||
|
learner_config=learner_config.SerializeToString(),
|
||||||
|
dropout_seed=123,
|
||||||
|
center_bias=True,
|
||||||
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
|
||||||
|
session.run(grow_op)
|
||||||
|
|
||||||
|
# Expect the split with bigger handler_id, i.e. handler 3 to be chosen.
|
||||||
|
# The grown tree should be finalized as max tree depth is 1.
|
||||||
|
new_stamp, serialized = session.run(
|
||||||
|
model_ops.tree_ensemble_serialize(tree_ensemble_handle))
|
||||||
|
stats = session.run(
|
||||||
|
training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
|
||||||
|
tree_ensemble_config.ParseFromString(serialized)
|
||||||
|
expected_result = """
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
oblivious_dense_float_binary_split {
|
||||||
|
feature_column: 0
|
||||||
|
threshold: 7
|
||||||
|
}
|
||||||
|
node_metadata {
|
||||||
|
gain: 7.62
|
||||||
|
original_oblivious_leaves {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
vector {
|
||||||
|
value: -4.375
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
vector {
|
||||||
|
value: 7.143
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tree_weights: 0.1
|
||||||
|
tree_metadata {
|
||||||
|
num_tree_weight_updates: 1
|
||||||
|
num_layers_grown: 1
|
||||||
|
is_finalized: true
|
||||||
|
}
|
||||||
|
growing_metadata {
|
||||||
|
num_trees_attempted: 1
|
||||||
|
num_layers_attempted: 1
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
self.assertEqual(new_stamp, 1)
|
||||||
|
self.assertEqual(stats.num_trees, 1)
|
||||||
|
self.assertEqual(stats.num_layers, 1)
|
||||||
|
self.assertEqual(stats.active_tree, 1)
|
||||||
|
self.assertEqual(stats.active_layer, 1)
|
||||||
|
self.assertEqual(stats.attempted_trees, 1)
|
||||||
|
self.assertEqual(stats.attempted_layers, 1)
|
||||||
|
self.assertProtoEquals(expected_result, tree_ensemble_config)
|
||||||
|
|
||||||
def testGrowExistingEnsembleTreeNotFinalized(self):
|
def testGrowExistingEnsembleTreeNotFinalized(self):
|
||||||
"""Test growing an existing ensemble with the last tree not finalized."""
|
"""Test growing an existing ensemble with the last tree not finalized."""
|
||||||
with self.test_session() as session:
|
with self.test_session() as session:
|
||||||
@ -476,7 +607,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect the split for partition 1 to be chosen from handler 1 and
|
# Expect the split for partition 1 to be chosen from handler 1 and
|
||||||
@ -661,7 +793,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect a new tree to be added with the split from handler 1.
|
# Expect a new tree to be added with the split from handler 1.
|
||||||
@ -798,7 +931,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect the ensemble to be empty.
|
# Expect the ensemble to be empty.
|
||||||
@ -869,7 +1003,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect the simpler split from handler 1 to be chosen.
|
# Expect the simpler split from handler 1 to be chosen.
|
||||||
@ -971,7 +1106,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect the split from handler 2 to be chosen despite the negative gain.
|
# Expect the split from handler 2 to be chosen despite the negative gain.
|
||||||
@ -1053,7 +1189,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect the ensemble to be empty as post-pruning will prune
|
# Expect the ensemble to be empty as post-pruning will prune
|
||||||
@ -1120,7 +1257,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect the split from handler 2 to be chosen despite the negative gain.
|
# Expect the split from handler 2 to be chosen despite the negative gain.
|
||||||
@ -1200,7 +1338,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect the negative gain split of partition 1 to be pruned and the
|
# Expect the negative gain split of partition 1 to be pruned and the
|
||||||
@ -1371,7 +1510,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect the split for partition 1 to be chosen from handler 1 and
|
# Expect the split for partition 1 to be chosen from handler 1 and
|
||||||
@ -1470,6 +1610,193 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(stats.attempted_layers, 2)
|
self.assertEqual(stats.attempted_layers, 2)
|
||||||
self.assertProtoEquals(expected_result, tree_ensemble_config)
|
self.assertProtoEquals(expected_result, tree_ensemble_config)
|
||||||
|
|
||||||
|
def testGrowEnsembleTreeLayerByLayerObliviousCase(self):
|
||||||
|
"""Test growing an existing ensemble with the last tree not finalized."""
|
||||||
|
with self.test_session() as session:
|
||||||
|
# Create existing ensemble with one root split
|
||||||
|
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
|
||||||
|
text_format.Merge(
|
||||||
|
"""
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
oblivious_dense_float_binary_split {
|
||||||
|
feature_column: 4
|
||||||
|
threshold: 7
|
||||||
|
}
|
||||||
|
node_metadata {
|
||||||
|
gain: 7.62
|
||||||
|
original_oblivious_leaves {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
vector {
|
||||||
|
value: 7.143
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
vector {
|
||||||
|
value: -4.375
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tree_weights: 0.1
|
||||||
|
tree_metadata {
|
||||||
|
num_tree_weight_updates: 1
|
||||||
|
num_layers_grown: 1
|
||||||
|
}
|
||||||
|
growing_metadata {
|
||||||
|
num_trees_attempted: 1
|
||||||
|
num_layers_attempted: 1
|
||||||
|
}
|
||||||
|
""", tree_ensemble_config)
|
||||||
|
tree_ensemble_handle = model_ops.tree_ensemble_variable(
|
||||||
|
stamp_token=0,
|
||||||
|
tree_ensemble_config=tree_ensemble_config.SerializeToString(),
|
||||||
|
name="tree_ensemble")
|
||||||
|
resources.initialize_resources(resources.shared_resources()).run()
|
||||||
|
|
||||||
|
# Prepare learner config.
|
||||||
|
learner_config = _gen_learner_config(
|
||||||
|
num_classes=2,
|
||||||
|
l1_reg=0,
|
||||||
|
l2_reg=0,
|
||||||
|
tree_complexity=0,
|
||||||
|
max_depth=3,
|
||||||
|
min_node_weight=0,
|
||||||
|
pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
|
||||||
|
growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER)
|
||||||
|
|
||||||
|
# Prepare handler inputs.
|
||||||
|
handler1_partitions = np.array([0], dtype=np.int32)
|
||||||
|
handler1_gains = np.array([1.4], dtype=np.float32)
|
||||||
|
handler1_split = [
|
||||||
|
_gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5])
|
||||||
|
]
|
||||||
|
handler2_partitions = np.array([0], dtype=np.int32)
|
||||||
|
handler2_gains = np.array([2.7], dtype=np.float32)
|
||||||
|
handler2_split = [
|
||||||
|
_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4]),
|
||||||
|
]
|
||||||
|
handler3_partitions = np.array([0], dtype=np.int32)
|
||||||
|
handler3_gains = np.array([1.7], dtype=np.float32)
|
||||||
|
handler3_split = [
|
||||||
|
_gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1])
|
||||||
|
]
|
||||||
|
|
||||||
|
# Grow tree ensemble layer by layer.
|
||||||
|
grow_op = training_ops.grow_tree_ensemble(
|
||||||
|
tree_ensemble_handle,
|
||||||
|
stamp_token=0,
|
||||||
|
next_stamp_token=1,
|
||||||
|
learning_rate=0.1,
|
||||||
|
partition_ids=[
|
||||||
|
handler1_partitions, handler2_partitions, handler3_partitions
|
||||||
|
],
|
||||||
|
gains=[handler1_gains, handler2_gains, handler3_gains],
|
||||||
|
splits=[handler1_split, handler2_split, handler3_split],
|
||||||
|
learner_config=learner_config.SerializeToString(),
|
||||||
|
dropout_seed=123,
|
||||||
|
center_bias=True,
|
||||||
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
|
||||||
|
session.run(grow_op)
|
||||||
|
|
||||||
|
# Expect the split for partition 1 to be chosen from handler 1 and
|
||||||
|
# the split for partition 2 to be chosen from handler 2.
|
||||||
|
# The grown tree should not be finalized as max tree depth is 3 and
|
||||||
|
# it's only grown 2 layers.
|
||||||
|
# The partition 1 split weights get added to original leaf weight 7.143.
|
||||||
|
# The partition 2 split weights get added to original leaf weight -4.375.
|
||||||
|
new_stamp, serialized = session.run(
|
||||||
|
model_ops.tree_ensemble_serialize(tree_ensemble_handle))
|
||||||
|
stats = session.run(
|
||||||
|
training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
|
||||||
|
tree_ensemble_config.ParseFromString(serialized)
|
||||||
|
expected_result = """
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
oblivious_dense_float_binary_split {
|
||||||
|
feature_column: 4
|
||||||
|
threshold: 7
|
||||||
|
}
|
||||||
|
node_metadata {
|
||||||
|
gain: 7.62
|
||||||
|
original_oblivious_leaves {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
oblivious_dense_float_binary_split {
|
||||||
|
feature_column: 0
|
||||||
|
threshold: 0.23
|
||||||
|
}
|
||||||
|
node_metadata {
|
||||||
|
gain: 2.7
|
||||||
|
original_oblivious_leaves {
|
||||||
|
vector {
|
||||||
|
value: 7.143
|
||||||
|
}
|
||||||
|
}
|
||||||
|
original_oblivious_leaves {
|
||||||
|
vector {
|
||||||
|
value: -4.375
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
vector {
|
||||||
|
value: 6.543
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
vector {
|
||||||
|
value: 7.383
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
vector {
|
||||||
|
value: -4.075
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
vector {
|
||||||
|
value: -3.975
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tree_weights: 0.1
|
||||||
|
tree_metadata {
|
||||||
|
num_tree_weight_updates: 1
|
||||||
|
num_layers_grown: 2
|
||||||
|
}
|
||||||
|
growing_metadata {
|
||||||
|
num_trees_attempted: 1
|
||||||
|
num_layers_attempted: 2
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
self.assertEqual(new_stamp, 1)
|
||||||
|
self.assertEqual(stats.num_trees, 0)
|
||||||
|
self.assertEqual(stats.num_layers, 2)
|
||||||
|
self.assertEqual(stats.active_tree, 1)
|
||||||
|
self.assertEqual(stats.active_layer, 2)
|
||||||
|
self.assertEqual(stats.attempted_trees, 1)
|
||||||
|
self.assertEqual(stats.attempted_layers, 2)
|
||||||
|
self.assertProtoEquals(expected_result, tree_ensemble_config)
|
||||||
|
|
||||||
def testGrowExistingEnsembleTreeFinalizedWithDropout(self):
|
def testGrowExistingEnsembleTreeFinalizedWithDropout(self):
|
||||||
"""Test growing an existing ensemble with the last tree finalized."""
|
"""Test growing an existing ensemble with the last tree finalized."""
|
||||||
with self.test_session() as session:
|
with self.test_session() as session:
|
||||||
@ -1575,7 +1902,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
# Expect a new tree to be added with the split from handler 1.
|
# Expect a new tree to be added with the split from handler 1.
|
||||||
@ -1700,7 +2028,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learner_config=learner_config.SerializeToString(),
|
learner_config=learner_config.SerializeToString(),
|
||||||
dropout_seed=123,
|
dropout_seed=123,
|
||||||
center_bias=True,
|
center_bias=True,
|
||||||
max_tree_depth=learner_config.constraints.max_tree_depth)
|
max_tree_depth=learner_config.constraints.max_tree_depth,
|
||||||
|
weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
|
||||||
session.run(grow_op)
|
session.run(grow_op)
|
||||||
|
|
||||||
_, serialized = session.run(
|
_, serialized = session.run(
|
||||||
|
@ -1076,7 +1076,8 @@ class GradientBoostedDecisionTreeModel(object):
|
|||||||
learner_config=self._learner_config_serialized,
|
learner_config=self._learner_config_serialized,
|
||||||
dropout_seed=dropout_seed,
|
dropout_seed=dropout_seed,
|
||||||
center_bias=self._center_bias,
|
center_bias=self._center_bias,
|
||||||
max_tree_depth=self._max_tree_depth)
|
max_tree_depth=self._max_tree_depth,
|
||||||
|
weak_learner_type=self._learner_config.weak_learner_type)
|
||||||
|
|
||||||
def _grow_ensemble_not_ready_fn():
|
def _grow_ensemble_not_ready_fn():
|
||||||
# Don't grow the ensemble, just update the stamp.
|
# Don't grow the ensemble, just update the stamp.
|
||||||
@ -1091,7 +1092,8 @@ class GradientBoostedDecisionTreeModel(object):
|
|||||||
learner_config=self._learner_config_serialized,
|
learner_config=self._learner_config_serialized,
|
||||||
dropout_seed=dropout_seed,
|
dropout_seed=dropout_seed,
|
||||||
center_bias=self._center_bias,
|
center_bias=self._center_bias,
|
||||||
max_tree_depth=self._max_tree_depth)
|
max_tree_depth=self._max_tree_depth,
|
||||||
|
weak_learner_type=self._learner_config.weak_learner_type)
|
||||||
|
|
||||||
def _grow_ensemble_fn():
|
def _grow_ensemble_fn():
|
||||||
# Conditionally grow an ensemble depending on whether the splits
|
# Conditionally grow an ensemble depending on whether the splits
|
||||||
|
Loading…
Reference in New Issue
Block a user