Enable equality split for UpdateEnsembleV2.
PiperOrigin-RevId: 263835411
This commit is contained in:
parent
e883e8b013
commit
1d825cff12
@ -46,6 +46,7 @@ cc_library(
|
|||||||
srcs = ["resources.cc"],
|
srcs = ["resources.cc"],
|
||||||
hdrs = ["resources.h"],
|
hdrs = ["resources.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":tree_helper",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||||
@ -95,6 +96,7 @@ tf_kernel_library(
|
|||||||
":tree_helper",
|
":tree_helper",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -48,6 +48,18 @@ message SparseVector {
|
|||||||
repeated float value = 2;
|
repeated float value = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum SplitTypeWithDefault {
|
||||||
|
INEQUALITY_DEFAULT_LEFT = 0;
|
||||||
|
INEQUALITY_DEFAULT_RIGHT = 1;
|
||||||
|
EQUALITY_DEFAULT_RIGHT = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum DefaultDirection {
|
||||||
|
// Left is the default direction.
|
||||||
|
DEFAULT_LEFT = 0;
|
||||||
|
DEFAULT_RIGHT = 1;
|
||||||
|
}
|
||||||
|
|
||||||
message BucketizedSplit {
|
message BucketizedSplit {
|
||||||
// Float feature column and split threshold describing
|
// Float feature column and split threshold describing
|
||||||
// the rule feature <= threshold.
|
// the rule feature <= threshold.
|
||||||
@ -56,11 +68,6 @@ message BucketizedSplit {
|
|||||||
// If feature column is multivalent, this holds the index of the dimension
|
// If feature column is multivalent, this holds the index of the dimension
|
||||||
// for the split. Defaults to 0.
|
// for the split. Defaults to 0.
|
||||||
int32 dimension_id = 5;
|
int32 dimension_id = 5;
|
||||||
enum DefaultDirection {
|
|
||||||
// Left is the default direction.
|
|
||||||
DEFAULT_LEFT = 0;
|
|
||||||
DEFAULT_RIGHT = 1;
|
|
||||||
}
|
|
||||||
// default direction for missing values.
|
// default direction for missing values.
|
||||||
DefaultDirection default_direction = 6;
|
DefaultDirection default_direction = 6;
|
||||||
|
|
||||||
@ -75,6 +82,9 @@ message CategoricalSplit {
|
|||||||
// value.
|
// value.
|
||||||
int32 feature_id = 1;
|
int32 feature_id = 1;
|
||||||
int32 value = 2;
|
int32 value = 2;
|
||||||
|
// If feature column is multivalent, this holds the index of the dimension
|
||||||
|
// for the split. Defaults to 0.
|
||||||
|
int32 dimension_id = 5;
|
||||||
|
|
||||||
// Node children indexing into a contiguous
|
// Node children indexing into a contiguous
|
||||||
// vector of nodes starting from the root.
|
// vector of nodes starting from the root.
|
||||||
|
@ -14,8 +14,10 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/boosted_trees/resources.h"
|
#include "tensorflow/core/kernels/boosted_trees/resources.h"
|
||||||
|
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
||||||
|
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
|
||||||
@ -265,11 +267,50 @@ int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
|
void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
|
||||||
const int32 tree_id, const int32 node_id, const int32 feature_id,
|
const int32 tree_id,
|
||||||
const int32 dimension_id, const int32 threshold, const float gain,
|
const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
|
||||||
const float left_contrib, const float right_contrib, int32* left_node_id,
|
int32* left_node_id, int32* right_node_id) {
|
||||||
int32* right_node_id) {
|
const auto candidate = split_entry.second;
|
||||||
|
auto* node = AddLeafNodes(tree_id, split_entry, left_node_id, right_node_id);
|
||||||
|
auto* new_split = node->mutable_bucketized_split();
|
||||||
|
new_split->set_feature_id(candidate.feature_idx);
|
||||||
|
new_split->set_threshold(candidate.threshold);
|
||||||
|
new_split->set_dimension_id(candidate.dimension_id);
|
||||||
|
new_split->set_left_id(*left_node_id);
|
||||||
|
new_split->set_right_id(*right_node_id);
|
||||||
|
|
||||||
|
boosted_trees::SplitTypeWithDefault split_type_with_default;
|
||||||
|
bool parsed = boosted_trees::SplitTypeWithDefault_Parse(
|
||||||
|
candidate.split_type, &split_type_with_default);
|
||||||
|
DCHECK(parsed);
|
||||||
|
if (split_type_with_default == boosted_trees::INEQUALITY_DEFAULT_RIGHT) {
|
||||||
|
new_split->set_default_direction(boosted_trees::DEFAULT_RIGHT);
|
||||||
|
} else {
|
||||||
|
new_split->set_default_direction(boosted_trees::DEFAULT_LEFT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void BoostedTreesEnsembleResource::AddCategoricalSplitNode(
|
||||||
|
const int32 tree_id,
|
||||||
|
const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
|
||||||
|
int32* left_node_id, int32* right_node_id) {
|
||||||
|
const auto candidate = split_entry.second;
|
||||||
|
auto* node = AddLeafNodes(tree_id, split_entry, left_node_id, right_node_id);
|
||||||
|
auto* new_split = node->mutable_categorical_split();
|
||||||
|
new_split->set_feature_id(candidate.feature_idx);
|
||||||
|
new_split->set_value(candidate.threshold);
|
||||||
|
new_split->set_dimension_id(candidate.dimension_id);
|
||||||
|
new_split->set_left_id(*left_node_id);
|
||||||
|
new_split->set_right_id(*right_node_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
boosted_trees::Node* BoostedTreesEnsembleResource::AddLeafNodes(
|
||||||
|
const int32 tree_id,
|
||||||
|
const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
|
||||||
|
int32* left_node_id, int32* right_node_id) {
|
||||||
auto* tree = tree_ensemble_->mutable_trees(tree_id);
|
auto* tree = tree_ensemble_->mutable_trees(tree_id);
|
||||||
|
const auto node_id = split_entry.first;
|
||||||
|
const auto candidate = split_entry.second;
|
||||||
auto* node = tree->mutable_nodes(node_id);
|
auto* node = tree->mutable_nodes(node_id);
|
||||||
DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
|
DCHECK_EQ(node->node_case(), boosted_trees::Node::kLeaf);
|
||||||
float prev_node_value = node->leaf().scalar();
|
float prev_node_value = node->leaf().scalar();
|
||||||
@ -282,16 +323,13 @@ void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
|
|||||||
node->mutable_metadata()->mutable_original_leaf()->Swap(
|
node->mutable_metadata()->mutable_original_leaf()->Swap(
|
||||||
node->mutable_leaf());
|
node->mutable_leaf());
|
||||||
}
|
}
|
||||||
node->mutable_metadata()->set_gain(gain);
|
node->mutable_metadata()->set_gain(candidate.gain);
|
||||||
auto* new_split = node->mutable_bucketized_split();
|
|
||||||
new_split->set_feature_id(feature_id);
|
|
||||||
new_split->set_threshold(threshold);
|
|
||||||
new_split->set_dimension_id(dimension_id);
|
|
||||||
new_split->set_left_id(*left_node_id);
|
|
||||||
new_split->set_right_id(*right_node_id);
|
|
||||||
// TODO(npononareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
|
// TODO(npononareva): this is LAYER-BY-LAYER boosting; add WHOLE-TREE.
|
||||||
left_node->mutable_leaf()->set_scalar(prev_node_value + left_contrib);
|
left_node->mutable_leaf()->set_scalar(prev_node_value +
|
||||||
right_node->mutable_leaf()->set_scalar(prev_node_value + right_contrib);
|
candidate.left_node_contrib);
|
||||||
|
right_node->mutable_leaf()->set_scalar(prev_node_value +
|
||||||
|
candidate.right_node_contrib);
|
||||||
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BoostedTreesEnsembleResource::Reset() {
|
void BoostedTreesEnsembleResource::Reset() {
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
|
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
|
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ namespace tensorflow {
|
|||||||
// Forward declaration for proto class TreeEnsemble
|
// Forward declaration for proto class TreeEnsemble
|
||||||
namespace boosted_trees {
|
namespace boosted_trees {
|
||||||
class TreeEnsemble;
|
class TreeEnsemble;
|
||||||
|
class Node;
|
||||||
} // namespace boosted_trees
|
} // namespace boosted_trees
|
||||||
|
|
||||||
// A StampedResource is a resource that has a stamp token associated with it.
|
// A StampedResource is a resource that has a stamp token associated with it.
|
||||||
@ -105,13 +107,17 @@ class BoostedTreesEnsembleResource : public StampedResource {
|
|||||||
// Adds new tree with one node to the ensemble and sets node's value to logits
|
// Adds new tree with one node to the ensemble and sets node's value to logits
|
||||||
int32 AddNewTreeWithLogits(const float weight, const float logits);
|
int32 AddNewTreeWithLogits(const float weight, const float logits);
|
||||||
|
|
||||||
// Grows the tree by adding a split and leaves.
|
// Grows the tree by adding a bucketized split and leaves.
|
||||||
void AddBucketizedSplitNode(const int32 tree_id, const int32 node_id,
|
void AddBucketizedSplitNode(
|
||||||
const int32 feature_id, const int32 dimension_id,
|
const int32 tree_id,
|
||||||
const int32 threshold, const float gain,
|
const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
|
||||||
const float left_contrib,
|
int32* left_node_id, int32* right_node_id);
|
||||||
const float right_contrib, int32* left_node_id,
|
|
||||||
int32* right_node_id);
|
// Grows the tree by adding a categorical split and leaves.
|
||||||
|
void AddCategoricalSplitNode(
|
||||||
|
const int32 tree_id,
|
||||||
|
const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
|
||||||
|
int32* left_node_id, int32* right_node_id);
|
||||||
|
|
||||||
// Retrieves tree weights and returns as a vector.
|
// Retrieves tree weights and returns as a vector.
|
||||||
// It involves a copy, so should be called only sparingly (like once per
|
// It involves a copy, so should be called only sparingly (like once per
|
||||||
@ -167,6 +173,11 @@ class BoostedTreesEnsembleResource : public StampedResource {
|
|||||||
protobuf::Arena arena_;
|
protobuf::Arena arena_;
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
boosted_trees::TreeEnsemble* tree_ensemble_;
|
boosted_trees::TreeEnsemble* tree_ensemble_;
|
||||||
|
|
||||||
|
boosted_trees::Node* AddLeafNodes(
|
||||||
|
int32 tree_id,
|
||||||
|
const std::pair<int32, boosted_trees::SplitCandidate>& split_entry,
|
||||||
|
int32* left_node_id, int32* right_node_id);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -20,16 +20,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
||||||
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
|
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// TODO(tanzheny): Make these const as proto enum.
|
|
||||||
const char kInequalityDefaultLeft[] = "inequality_default_left";
|
|
||||||
const char kInequalityDefaultRight[] = "inequality_default_right";
|
|
||||||
const char kEqualityDefaultRight[] = "equality_default_right";
|
|
||||||
|
|
||||||
using Matrix =
|
using Matrix =
|
||||||
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||||
using ConstMatrixMap = Eigen::Map<const Matrix>;
|
using ConstMatrixMap = Eigen::Map<const Matrix>;
|
||||||
@ -459,6 +455,12 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
|
|||||||
cum_hess.push_back(total_hess);
|
cum_hess.push_back(total_hess);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const string kInequalityDefaultLeft =
|
||||||
|
boosted_trees::SplitTypeWithDefault_Name(
|
||||||
|
boosted_trees::INEQUALITY_DEFAULT_LEFT);
|
||||||
|
const string kInequalityDefaultRight =
|
||||||
|
boosted_trees::SplitTypeWithDefault_Name(
|
||||||
|
boosted_trees::INEQUALITY_DEFAULT_RIGHT);
|
||||||
|
|
||||||
// Iterate from left to right, excluding default bucket.
|
// Iterate from left to right, excluding default bucket.
|
||||||
for (int bucket = 0; bucket < num_buckets; ++bucket) {
|
for (int bucket = 0; bucket < num_buckets; ++bucket) {
|
||||||
@ -491,6 +493,9 @@ class BoostedTreesCalculateBestFeatureSplitOp : public OpKernel {
|
|||||||
const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
|
const float l2, float* best_gain, int32* best_bucket, int32* best_f_dim,
|
||||||
string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
|
string* best_split_type, Eigen::VectorXf* best_contrib_for_left,
|
||||||
Eigen::VectorXf* best_contrib_for_right) {
|
Eigen::VectorXf* best_contrib_for_right) {
|
||||||
|
const string kEqualityDefaultRight =
|
||||||
|
boosted_trees::SplitTypeWithDefault_Name(
|
||||||
|
boosted_trees::EQUALITY_DEFAULT_RIGHT);
|
||||||
for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
|
for (int f_dim = 0; f_dim < feature_dims; ++f_dim) {
|
||||||
for (int bucket = 0; bucket < num_buckets; ++bucket) {
|
for (int bucket = 0; bucket < num_buckets; ++bucket) {
|
||||||
ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
|
ConstVectorMap stats_vec(&stats_summary(node_id, f_dim, bucket, 0),
|
||||||
@ -734,7 +739,8 @@ class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel {
|
|||||||
float best_gain = std::numeric_limits<float>::lowest();
|
float best_gain = std::numeric_limits<float>::lowest();
|
||||||
float best_bucket = 0;
|
float best_bucket = 0;
|
||||||
float best_f_dim = 0;
|
float best_f_dim = 0;
|
||||||
string best_split_type = kInequalityDefaultLeft;
|
string best_split_type = boosted_trees::SplitTypeWithDefault_Name(
|
||||||
|
boosted_trees::INEQUALITY_DEFAULT_LEFT);
|
||||||
float best_contrib_for_left = 0.0;
|
float best_contrib_for_left = 0.0;
|
||||||
float best_contrib_for_right = 0.0;
|
float best_contrib_for_right = 0.0;
|
||||||
// the sum of gradients including default bucket.
|
// the sum of gradients including default bucket.
|
||||||
@ -801,7 +807,8 @@ class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel {
|
|||||||
best_gain = gain_for_left + gain_for_right;
|
best_gain = gain_for_left + gain_for_right;
|
||||||
best_bucket = bucket_id;
|
best_bucket = bucket_id;
|
||||||
best_f_dim = feature_dim;
|
best_f_dim = feature_dim;
|
||||||
best_split_type = kInequalityDefaultRight;
|
best_split_type = boosted_trees::SplitTypeWithDefault_Name(
|
||||||
|
boosted_trees::INEQUALITY_DEFAULT_RIGHT);
|
||||||
best_contrib_for_left = contrib_for_left[0];
|
best_contrib_for_left = contrib_for_left[0];
|
||||||
best_contrib_for_right = contrib_for_right[0];
|
best_contrib_for_right = contrib_for_right[0];
|
||||||
}
|
}
|
||||||
@ -818,7 +825,8 @@ class BoostedTreesSparseCalculateBestFeatureSplitOp : public OpKernel {
|
|||||||
best_gain = gain_for_left + gain_for_right;
|
best_gain = gain_for_left + gain_for_right;
|
||||||
best_bucket = bucket_id;
|
best_bucket = bucket_id;
|
||||||
best_f_dim = feature_dim;
|
best_f_dim = feature_dim;
|
||||||
best_split_type = kInequalityDefaultLeft;
|
best_split_type = boosted_trees::SplitTypeWithDefault_Name(
|
||||||
|
boosted_trees::INEQUALITY_DEFAULT_LEFT);
|
||||||
best_contrib_for_left = contrib_for_left[0];
|
best_contrib_for_left = contrib_for_left[0];
|
||||||
best_contrib_for_right = contrib_for_right[0];
|
best_contrib_for_right = contrib_for_right[0];
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
|
||||||
#include "tensorflow/core/kernels/boosted_trees/resources.h"
|
#include "tensorflow/core/kernels/boosted_trees/resources.h"
|
||||||
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
|
#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
|
||||||
#include "tensorflow/core/lib/core/refcount.h"
|
#include "tensorflow/core/lib/core/refcount.h"
|
||||||
@ -26,19 +27,6 @@ namespace {
|
|||||||
constexpr float kLayerByLayerTreeWeight = 1.0;
|
constexpr float kLayerByLayerTreeWeight = 1.0;
|
||||||
constexpr float kMinDeltaForCenterBias = 0.01;
|
constexpr float kMinDeltaForCenterBias = 0.01;
|
||||||
|
|
||||||
// TODO(nponomareva, youngheek): consider using vector.
|
|
||||||
struct SplitCandidate {
|
|
||||||
SplitCandidate() {}
|
|
||||||
|
|
||||||
// Index in the list of the feature ids.
|
|
||||||
int64 feature_idx;
|
|
||||||
|
|
||||||
// Index in the tensor of node_ids for the feature with idx feature_idx.
|
|
||||||
int64 candidate_idx;
|
|
||||||
|
|
||||||
float gain;
|
|
||||||
};
|
|
||||||
|
|
||||||
enum PruningMode { kNoPruning = 0, kPrePruning = 1, kPostPruning = 2 };
|
enum PruningMode { kNoPruning = 0, kPrePruning = 1, kPostPruning = 2 };
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -91,9 +79,10 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
|||||||
const auto learning_rate = learning_rate_t->scalar<float>()();
|
const auto learning_rate = learning_rate_t->scalar<float>()();
|
||||||
|
|
||||||
// Find best splits for each active node.
|
// Find best splits for each active node.
|
||||||
std::map<int32, SplitCandidate> best_splits;
|
std::map<int32, boosted_trees::SplitCandidate> best_splits;
|
||||||
FindBestSplitsPerNode(context, node_ids_list, gains_list, feature_ids,
|
FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list,
|
||||||
&best_splits);
|
thresholds_list, left_node_contribs,
|
||||||
|
right_node_contribs, feature_ids, &best_splits);
|
||||||
|
|
||||||
int32 current_tree =
|
int32 current_tree =
|
||||||
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
|
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
|
||||||
@ -113,17 +102,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
|||||||
int32 node_id_start = ensemble_resource->GetNumNodes(current_tree);
|
int32 node_id_start = ensemble_resource->GetNumNodes(current_tree);
|
||||||
// Add the splits to the tree.
|
// Add the splits to the tree.
|
||||||
for (auto& split_entry : best_splits) {
|
for (auto& split_entry : best_splits) {
|
||||||
const int32 node_id = split_entry.first;
|
const float gain = split_entry.second.gain;
|
||||||
const SplitCandidate& candidate = split_entry.second;
|
|
||||||
|
|
||||||
const int64 feature_idx = candidate.feature_idx;
|
|
||||||
const int64 candidate_idx = candidate.candidate_idx;
|
|
||||||
|
|
||||||
const int32 feature_id = feature_ids(feature_idx);
|
|
||||||
const int32 threshold =
|
|
||||||
thresholds_list[feature_idx].vec<int32>()(candidate_idx);
|
|
||||||
const float gain = gains_list[feature_idx].vec<float>()(candidate_idx);
|
|
||||||
|
|
||||||
if (pruning_mode_ == kPrePruning) {
|
if (pruning_mode_ == kPrePruning) {
|
||||||
// Don't consider negative splits if we're pre-pruning the tree.
|
// Don't consider negative splits if we're pre-pruning the tree.
|
||||||
// Note that zero-gain splits are acceptable.
|
// Note that zero-gain splits are acceptable.
|
||||||
@ -131,22 +110,13 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// For now assume that the weights vectors are one dimensional.
|
|
||||||
// TODO(nponomareva): change here for multiclass.
|
|
||||||
const float left_contrib =
|
|
||||||
learning_rate *
|
|
||||||
left_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
|
|
||||||
const float right_contrib =
|
|
||||||
learning_rate *
|
|
||||||
right_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
|
|
||||||
|
|
||||||
// unused.
|
// unused.
|
||||||
int32 left_node_id;
|
int32 left_node_id;
|
||||||
int32 right_node_id;
|
int32 right_node_id;
|
||||||
|
|
||||||
ensemble_resource->AddBucketizedSplitNode(
|
ensemble_resource->AddBucketizedSplitNode(current_tree, split_entry,
|
||||||
current_tree, node_id, feature_id, 0, threshold, gain, left_contrib,
|
&left_node_id, &right_node_id);
|
||||||
right_contrib, &left_node_id, &right_node_id);
|
|
||||||
split_happened = true;
|
split_happened = true;
|
||||||
}
|
}
|
||||||
int32 node_id_end = ensemble_resource->GetNumNodes(current_tree);
|
int32 node_id_end = ensemble_resource->GetNumNodes(current_tree);
|
||||||
@ -196,14 +166,22 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
|||||||
// Helper method which effectively does a reduce over all split candidates
|
// Helper method which effectively does a reduce over all split candidates
|
||||||
// and finds the best split for each node.
|
// and finds the best split for each node.
|
||||||
void FindBestSplitsPerNode(
|
void FindBestSplitsPerNode(
|
||||||
OpKernelContext* const context, const OpInputList& node_ids_list,
|
OpKernelContext* const context, const float learning_rate,
|
||||||
const OpInputList& gains_list,
|
const OpInputList& node_ids_list, const OpInputList& gains_list,
|
||||||
|
const OpInputList& thresholds_list,
|
||||||
|
const OpInputList& left_node_contribs_list,
|
||||||
|
const OpInputList& right_node_contribs_list,
|
||||||
const TTypes<const int32>::Vec& feature_ids,
|
const TTypes<const int32>::Vec& feature_ids,
|
||||||
std::map<int32, SplitCandidate>* best_split_per_node) {
|
std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) {
|
||||||
// Find best split per node going through every feature candidate.
|
// Find best split per node going through every feature candidate.
|
||||||
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
|
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
|
||||||
const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
|
const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
|
||||||
const auto& gains = gains_list[feature_idx].vec<float>();
|
const auto& gains = gains_list[feature_idx].vec<float>();
|
||||||
|
const auto& thresholds = thresholds_list[feature_idx].vec<int32>();
|
||||||
|
const auto& left_node_contribs =
|
||||||
|
left_node_contribs_list[feature_idx].matrix<float>();
|
||||||
|
const auto& right_node_contribs =
|
||||||
|
right_node_contribs_list[feature_idx].matrix<float>();
|
||||||
|
|
||||||
for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
|
for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
|
||||||
++candidate_idx) {
|
++candidate_idx) {
|
||||||
@ -212,16 +190,24 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
|||||||
const auto& gain = gains(candidate_idx);
|
const auto& gain = gains(candidate_idx);
|
||||||
|
|
||||||
auto best_split_it = best_split_per_node->find(node_id);
|
auto best_split_it = best_split_per_node->find(node_id);
|
||||||
SplitCandidate candidate;
|
boosted_trees::SplitCandidate candidate;
|
||||||
candidate.feature_idx = feature_idx;
|
candidate.feature_idx = feature_ids(feature_idx);
|
||||||
candidate.candidate_idx = candidate_idx;
|
candidate.candidate_idx = candidate_idx;
|
||||||
candidate.gain = gain;
|
candidate.gain = gain;
|
||||||
|
candidate.dimension_id = 0;
|
||||||
|
candidate.threshold = thresholds(candidate_idx);
|
||||||
|
candidate.left_node_contrib =
|
||||||
|
learning_rate * left_node_contribs(candidate_idx, 0);
|
||||||
|
candidate.right_node_contrib =
|
||||||
|
learning_rate * right_node_contribs(candidate_idx, 0);
|
||||||
|
candidate.split_type = boosted_trees::SplitTypeWithDefault_Name(
|
||||||
|
boosted_trees::INEQUALITY_DEFAULT_LEFT);
|
||||||
|
|
||||||
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
|
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
|
||||||
GainsAreEqual(gain, best_split_it->second.gain))) {
|
GainsAreEqual(gain, best_split_it->second.gain))) {
|
||||||
const auto best_candidate = (*best_split_per_node)[node_id];
|
const auto best_candidate = (*best_split_per_node)[node_id];
|
||||||
const int32 best_feature_id = feature_ids(best_candidate.feature_idx);
|
const int32 best_feature_id = best_candidate.feature_idx;
|
||||||
const int32 feature_id = feature_ids(candidate.feature_idx);
|
const int32 feature_id = candidate.feature_idx;
|
||||||
VLOG(2) << "Breaking ties on feature ids and buckets";
|
VLOG(2) << "Breaking ties on feature ids and buckets";
|
||||||
// Breaking ties deterministically.
|
// Breaking ties deterministically.
|
||||||
if (feature_id < best_feature_id) {
|
if (feature_id < best_feature_id) {
|
||||||
@ -299,9 +285,11 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
static_cast<PruningMode>(pruning_mode_t->scalar<int32>()());
|
static_cast<PruningMode>(pruning_mode_t->scalar<int32>()());
|
||||||
|
|
||||||
// Find best splits for each active node.
|
// Find best splits for each active node.
|
||||||
std::map<int32, SplitCandidate> best_splits;
|
std::map<int32, boosted_trees::SplitCandidate> best_splits;
|
||||||
FindBestSplitsPerNode(context, node_ids_list, gains_list, feature_ids,
|
FindBestSplitsPerNode(context, learning_rate, node_ids_list, gains_list,
|
||||||
&best_splits);
|
thresholds_list, dimension_ids_list,
|
||||||
|
left_node_contribs, right_node_contribs,
|
||||||
|
split_types_list, feature_ids, &best_splits);
|
||||||
|
|
||||||
int32 current_tree =
|
int32 current_tree =
|
||||||
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
|
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
|
||||||
@ -321,19 +309,8 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
int32 node_id_start = ensemble_resource->GetNumNodes(current_tree);
|
int32 node_id_start = ensemble_resource->GetNumNodes(current_tree);
|
||||||
// Add the splits to the tree.
|
// Add the splits to the tree.
|
||||||
for (auto& split_entry : best_splits) {
|
for (auto& split_entry : best_splits) {
|
||||||
const int32 node_id = split_entry.first;
|
const float gain = split_entry.second.gain;
|
||||||
const SplitCandidate& candidate = split_entry.second;
|
const string split_type = split_entry.second.split_type;
|
||||||
|
|
||||||
const int64 feature_idx = candidate.feature_idx;
|
|
||||||
const int32 feature_id = feature_ids(feature_idx);
|
|
||||||
|
|
||||||
const int64 candidate_idx = candidate.candidate_idx;
|
|
||||||
|
|
||||||
const int32 dimension_id =
|
|
||||||
dimension_ids_list[feature_idx].vec<int32>()(candidate_idx);
|
|
||||||
const int32 threshold =
|
|
||||||
thresholds_list[feature_idx].vec<int32>()(candidate_idx);
|
|
||||||
const float gain = gains_list[feature_idx].vec<float>()(candidate_idx);
|
|
||||||
|
|
||||||
if (pruning_mode == kPrePruning) {
|
if (pruning_mode == kPrePruning) {
|
||||||
// Don't consider negative splits if we're pre-pruning the tree.
|
// Don't consider negative splits if we're pre-pruning the tree.
|
||||||
@ -343,22 +320,23 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(crawles): change here for multiclass.
|
|
||||||
const float left_contrib =
|
|
||||||
learning_rate *
|
|
||||||
left_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
|
|
||||||
const float right_contrib =
|
|
||||||
learning_rate *
|
|
||||||
right_node_contribs[feature_idx].matrix<float>()(candidate_idx, 0);
|
|
||||||
|
|
||||||
// unused.
|
// unused.
|
||||||
int32 left_node_id;
|
int32 left_node_id;
|
||||||
int32 right_node_id;
|
int32 right_node_id;
|
||||||
|
|
||||||
// TODO(tanzheny): add categorical split.
|
boosted_trees::SplitTypeWithDefault split_type_with_default;
|
||||||
ensemble_resource->AddBucketizedSplitNode(
|
bool parsed = boosted_trees::SplitTypeWithDefault_Parse(
|
||||||
current_tree, node_id, feature_id, dimension_id, threshold, gain,
|
split_type, &split_type_with_default);
|
||||||
left_contrib, right_contrib, &left_node_id, &right_node_id);
|
DCHECK(parsed);
|
||||||
|
if (split_type_with_default == boosted_trees::EQUALITY_DEFAULT_RIGHT) {
|
||||||
|
// Add equality split to the node.
|
||||||
|
ensemble_resource->AddCategoricalSplitNode(
|
||||||
|
current_tree, split_entry, &left_node_id, &right_node_id);
|
||||||
|
} else {
|
||||||
|
// Add inequality split to the node.
|
||||||
|
ensemble_resource->AddBucketizedSplitNode(
|
||||||
|
current_tree, split_entry, &left_node_id, &right_node_id);
|
||||||
|
}
|
||||||
split_happened = true;
|
split_happened = true;
|
||||||
}
|
}
|
||||||
int32 node_id_end = ensemble_resource->GetNumNodes(current_tree);
|
int32 node_id_end = ensemble_resource->GetNumNodes(current_tree);
|
||||||
@ -408,32 +386,54 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
// Helper method which effectively does a reduce over all split candidates
|
// Helper method which effectively does a reduce over all split candidates
|
||||||
// and finds the best split for each node.
|
// and finds the best split for each node.
|
||||||
void FindBestSplitsPerNode(
|
void FindBestSplitsPerNode(
|
||||||
OpKernelContext* const context, const OpInputList& node_ids_list,
|
OpKernelContext* const context, const float learning_rate,
|
||||||
const OpInputList& gains_list,
|
const OpInputList& node_ids_list, const OpInputList& gains_list,
|
||||||
|
const OpInputList& thresholds_list, const OpInputList& dimension_ids_list,
|
||||||
|
const OpInputList& left_node_contribs_list,
|
||||||
|
const OpInputList& right_node_contribs_list,
|
||||||
|
const OpInputList& split_types_list,
|
||||||
const TTypes<const int32>::Vec& feature_ids,
|
const TTypes<const int32>::Vec& feature_ids,
|
||||||
std::map<int32, SplitCandidate>* best_split_per_node) {
|
std::map<int32, boosted_trees::SplitCandidate>* best_split_per_node) {
|
||||||
// Find best split per node going through every feature candidate.
|
// Find best split per node going through every feature candidate.
|
||||||
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
|
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
|
||||||
const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
|
const auto& node_ids = node_ids_list[feature_idx].vec<int32>();
|
||||||
const auto& gains = gains_list[feature_idx].vec<float>();
|
const auto& gains = gains_list[feature_idx].vec<float>();
|
||||||
|
const auto& thresholds = thresholds_list[feature_idx].vec<int32>();
|
||||||
|
const auto& dimension_ids = dimension_ids_list[feature_idx].vec<int32>();
|
||||||
|
const auto& left_node_contribs =
|
||||||
|
left_node_contribs_list[feature_idx].matrix<float>();
|
||||||
|
const auto& right_node_contribs =
|
||||||
|
right_node_contribs_list[feature_idx].matrix<float>();
|
||||||
|
const auto& split_types = split_types_list[feature_idx].vec<string>();
|
||||||
|
|
||||||
for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
|
for (size_t candidate_idx = 0; candidate_idx < node_ids.size();
|
||||||
++candidate_idx) {
|
++candidate_idx) {
|
||||||
// Get current split candidate.
|
// Get current split candidate.
|
||||||
const auto& node_id = node_ids(candidate_idx);
|
const auto& node_id = node_ids(candidate_idx);
|
||||||
const auto& gain = gains(candidate_idx);
|
const auto& gain = gains(candidate_idx);
|
||||||
|
const auto& threshold = thresholds(candidate_idx);
|
||||||
|
const auto& dimension_id = dimension_ids(candidate_idx);
|
||||||
|
const auto& split_type = split_types(candidate_idx);
|
||||||
|
|
||||||
auto best_split_it = best_split_per_node->find(node_id);
|
auto best_split_it = best_split_per_node->find(node_id);
|
||||||
SplitCandidate candidate;
|
boosted_trees::SplitCandidate candidate;
|
||||||
candidate.feature_idx = feature_idx;
|
candidate.feature_idx = feature_ids(feature_idx);
|
||||||
candidate.candidate_idx = candidate_idx;
|
candidate.candidate_idx = candidate_idx;
|
||||||
candidate.gain = gain;
|
candidate.gain = gain;
|
||||||
|
candidate.threshold = threshold;
|
||||||
|
candidate.dimension_id = dimension_id;
|
||||||
|
// TODO(crawles): change here for multiclass.
|
||||||
|
candidate.left_node_contrib =
|
||||||
|
learning_rate * left_node_contribs(candidate_idx, 0);
|
||||||
|
candidate.right_node_contrib =
|
||||||
|
learning_rate * right_node_contribs(candidate_idx, 0);
|
||||||
|
candidate.split_type = split_type;
|
||||||
|
|
||||||
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
|
if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
|
||||||
GainsAreEqual(gain, best_split_it->second.gain))) {
|
GainsAreEqual(gain, best_split_it->second.gain))) {
|
||||||
const auto best_candidate = (*best_split_per_node)[node_id];
|
const auto best_candidate = (*best_split_per_node)[node_id];
|
||||||
const int32 best_feature_id = feature_ids(best_candidate.feature_idx);
|
const int32 best_feature_id = best_candidate.feature_idx;
|
||||||
const int32 feature_id = feature_ids(candidate.feature_idx);
|
const int32 feature_id = candidate.feature_idx;
|
||||||
VLOG(2) << "Breaking ties on feature ids and buckets";
|
VLOG(2) << "Breaking ties on feature ids and buckets";
|
||||||
// Breaking ties deterministically.
|
// Breaking ties deterministically.
|
||||||
if (feature_id < best_feature_id) {
|
if (feature_id < best_feature_id) {
|
||||||
|
@ -25,6 +25,27 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace boosted_trees {
|
||||||
|
// TODO(nponomareva, youngheek): consider using vector.
|
||||||
|
struct SplitCandidate {
|
||||||
|
SplitCandidate() {}
|
||||||
|
|
||||||
|
// Index in the list of the feature ids.
|
||||||
|
int64 feature_idx = 0;
|
||||||
|
|
||||||
|
// Index in the tensor of node_ids for the feature with idx feature_idx.
|
||||||
|
int64 candidate_idx = 0;
|
||||||
|
|
||||||
|
float gain = 0.0;
|
||||||
|
int32 threshold = 0.0;
|
||||||
|
int32 dimension_id = 0;
|
||||||
|
float left_node_contrib = 0.0;
|
||||||
|
float right_node_contrib = 0.0;
|
||||||
|
// The split type, i.e., with missing value to left/right.
|
||||||
|
string split_type;
|
||||||
|
};
|
||||||
|
} // namespace boosted_trees
|
||||||
|
|
||||||
static bool GainsAreEqual(const float g1, const float g2) {
|
static bool GainsAreEqual(const float g1, const float g2) {
|
||||||
const float kTolerance = 1e-15;
|
const float kTolerance = 1e-15;
|
||||||
return std::abs(g1 - g2) < kTolerance;
|
return std::abs(g1 - g2) < kTolerance;
|
||||||
|
@ -28,10 +28,9 @@ from tensorflow.python.ops import sparse_ops
|
|||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
|
|
||||||
_INEQUALITY_DEFAULT_LEFT = 'inequality_default_left'.encode('utf-8')
|
_INEQUALITY_DEFAULT_LEFT = 'INEQUALITY_DEFAULT_LEFT'.encode('utf-8')
|
||||||
_INEQUALITY_DEFAULT_RIGHT = 'inequality_default_right'.encode('utf-8')
|
_INEQUALITY_DEFAULT_RIGHT = 'INEQUALITY_DEFAULT_RIGHT'.encode('utf-8')
|
||||||
_EQUALITY_DEFAULT_LEFT = 'equality_default_left'.encode('utf-8')
|
_EQUALITY_DEFAULT_RIGHT = 'EQUALITY_DEFAULT_RIGHT'.encode('utf-8')
|
||||||
_EQUALITY_DEFAULT_RIGHT = 'equality_default_right'.encode('utf-8')
|
|
||||||
|
|
||||||
|
|
||||||
class StatsOpsTest(test_util.TensorFlowTestCase):
|
class StatsOpsTest(test_util.TensorFlowTestCase):
|
||||||
|
@ -26,6 +26,10 @@ from tensorflow.python.ops import boosted_trees_ops
|
|||||||
from tensorflow.python.ops import resources
|
from tensorflow.python.ops import resources
|
||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
|
|
||||||
|
_INEQUALITY_DEFAULT_LEFT = 'INEQUALITY_DEFAULT_LEFT'.encode('utf-8')
|
||||||
|
_INEQUALITY_DEFAULT_RIGHT = 'INEQUALITY_DEFAULT_RIGHT'.encode('utf-8')
|
||||||
|
_EQUALITY_DEFAULT_RIGHT = 'EQUALITY_DEFAULT_RIGHT'.encode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||||
"""Tests for growing tree ensemble from split candidates."""
|
"""Tests for growing tree ensemble from split candidates."""
|
||||||
@ -158,7 +162,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
feature1_thresholds = np.array([52], dtype=np.int32)
|
feature1_thresholds = np.array([52], dtype=np.int32)
|
||||||
feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32)
|
feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32)
|
||||||
feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32)
|
feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32)
|
||||||
feature1_inequality_split_types = np.array(['inequality_default_left'])
|
feature1_inequality_split_types = np.array([_INEQUALITY_DEFAULT_LEFT])
|
||||||
|
|
||||||
# Feature split with the highest gain.
|
# Feature split with the highest gain.
|
||||||
feature2_nodes = np.array([0], dtype=np.int32)
|
feature2_nodes = np.array([0], dtype=np.int32)
|
||||||
@ -167,7 +171,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
feature2_thresholds = np.array([7], dtype=np.int32)
|
feature2_thresholds = np.array([7], dtype=np.int32)
|
||||||
feature2_left_node_contribs = np.array([[-4.89]], dtype=np.float32)
|
feature2_left_node_contribs = np.array([[-4.89]], dtype=np.float32)
|
||||||
feature2_right_node_contribs = np.array([[5.3]], dtype=np.float32)
|
feature2_right_node_contribs = np.array([[5.3]], dtype=np.float32)
|
||||||
feature2_inequality_split_types = np.array(['inequality_default_right'])
|
feature2_inequality_split_types = np.array([_INEQUALITY_DEFAULT_RIGHT])
|
||||||
|
|
||||||
# Grow tree ensemble.
|
# Grow tree ensemble.
|
||||||
grow_op = boosted_trees_ops.update_ensemble_v2(
|
grow_op = boosted_trees_ops.update_ensemble_v2(
|
||||||
@ -208,6 +212,116 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
dimension_id: 1
|
dimension_id: 1
|
||||||
left_id: 1
|
left_id: 1
|
||||||
right_id: 2
|
right_id: 2
|
||||||
|
default_direction: DEFAULT_RIGHT
|
||||||
|
}
|
||||||
|
metadata {
|
||||||
|
gain: 7.65
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: -0.489
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: 0.53
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: 0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tree_weights: 1.0
|
||||||
|
tree_weights: 1.0
|
||||||
|
tree_metadata {
|
||||||
|
num_layers_grown: 1
|
||||||
|
is_finalized: true
|
||||||
|
}
|
||||||
|
tree_metadata {
|
||||||
|
}
|
||||||
|
growing_metadata {
|
||||||
|
num_trees_attempted: 1
|
||||||
|
num_layers_attempted: 1
|
||||||
|
last_layer_node_start: 0
|
||||||
|
last_layer_node_end: 1
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
self.assertEqual(new_stamp, 1)
|
||||||
|
self.assertProtoEquals(expected_result, tree_ensemble)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGrowWithEmptyEnsembleV2EqualitySplit(self):
|
||||||
|
"""Test growing an empty ensemble."""
|
||||||
|
with self.cached_session() as session:
|
||||||
|
# Create empty ensemble.
|
||||||
|
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
|
||||||
|
tree_ensemble_handle = tree_ensemble.resource_handle
|
||||||
|
resources.initialize_resources(resources.shared_resources()).run()
|
||||||
|
|
||||||
|
feature_ids = [0, 6]
|
||||||
|
|
||||||
|
# Prepare feature inputs.
|
||||||
|
feature1_nodes = np.array([0], dtype=np.int32)
|
||||||
|
feature1_gains = np.array([7.62], dtype=np.float32)
|
||||||
|
feature1_dimensions = np.array([0], dtype=np.int32)
|
||||||
|
feature1_thresholds = np.array([52], dtype=np.int32)
|
||||||
|
feature1_left_node_contribs = np.array([[-4.375]], dtype=np.float32)
|
||||||
|
feature1_right_node_contribs = np.array([[7.143]], dtype=np.float32)
|
||||||
|
feature1_inequality_split_types = np.array([_INEQUALITY_DEFAULT_LEFT])
|
||||||
|
|
||||||
|
# Feature split with the highest gain.
|
||||||
|
feature2_nodes = np.array([0], dtype=np.int32)
|
||||||
|
feature2_gains = np.array([7.65], dtype=np.float32)
|
||||||
|
feature2_dimensions = np.array([1], dtype=np.int32)
|
||||||
|
feature2_thresholds = np.array([7], dtype=np.int32)
|
||||||
|
feature2_left_node_contribs = np.array([[-4.89]], dtype=np.float32)
|
||||||
|
feature2_right_node_contribs = np.array([[5.3]], dtype=np.float32)
|
||||||
|
feature2_inequality_split_types = np.array([_EQUALITY_DEFAULT_RIGHT])
|
||||||
|
|
||||||
|
# Grow tree ensemble.
|
||||||
|
grow_op = boosted_trees_ops.update_ensemble_v2(
|
||||||
|
tree_ensemble_handle,
|
||||||
|
learning_rate=0.1,
|
||||||
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
|
# Tree will be finalized now, since we will reach depth 1.
|
||||||
|
max_depth=1,
|
||||||
|
feature_ids=feature_ids,
|
||||||
|
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
||||||
|
node_ids=[feature1_nodes, feature2_nodes],
|
||||||
|
gains=[feature1_gains, feature2_gains],
|
||||||
|
thresholds=[feature1_thresholds, feature2_thresholds],
|
||||||
|
left_node_contribs=[
|
||||||
|
feature1_left_node_contribs, feature2_left_node_contribs
|
||||||
|
],
|
||||||
|
right_node_contribs=[
|
||||||
|
feature1_right_node_contribs, feature2_right_node_contribs
|
||||||
|
],
|
||||||
|
split_types=[
|
||||||
|
feature1_inequality_split_types, feature2_inequality_split_types
|
||||||
|
],
|
||||||
|
)
|
||||||
|
session.run(grow_op)
|
||||||
|
|
||||||
|
new_stamp, serialized = session.run(tree_ensemble.serialize())
|
||||||
|
|
||||||
|
tree_ensemble = boosted_trees_pb2.TreeEnsemble()
|
||||||
|
tree_ensemble.ParseFromString(serialized)
|
||||||
|
|
||||||
|
# Note that since the tree is finalized, we added a new dummy tree.
|
||||||
|
expected_result = """
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
categorical_split {
|
||||||
|
feature_id: 6
|
||||||
|
value: 7
|
||||||
|
dimension_id: 1
|
||||||
|
left_id: 1
|
||||||
|
right_id: 2
|
||||||
}
|
}
|
||||||
metadata {
|
metadata {
|
||||||
gain: 7.65
|
gain: 7.65
|
||||||
@ -537,7 +651,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
feature1_thresholds = np.array([21], dtype=np.int32)
|
feature1_thresholds = np.array([21], dtype=np.int32)
|
||||||
feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
|
feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
|
||||||
feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
|
feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
|
||||||
feature1_split_types = np.array(['inequality_default_left'])
|
feature1_split_types = np.array([_INEQUALITY_DEFAULT_LEFT])
|
||||||
|
|
||||||
feature2_nodes = np.array([1, 2], dtype=np.int32)
|
feature2_nodes = np.array([1, 2], dtype=np.int32)
|
||||||
feature2_gains = np.array([0.63, 2.7], dtype=np.float32)
|
feature2_gains = np.array([0.63, 2.7], dtype=np.float32)
|
||||||
@ -546,7 +660,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
|
feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
|
||||||
feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)
|
feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)
|
||||||
feature2_split_types = np.array(
|
feature2_split_types = np.array(
|
||||||
['inequality_default_right', 'inequality_default_right'])
|
[_INEQUALITY_DEFAULT_RIGHT, _INEQUALITY_DEFAULT_RIGHT])
|
||||||
|
|
||||||
feature3_nodes = np.array([2], dtype=np.int32)
|
feature3_nodes = np.array([2], dtype=np.int32)
|
||||||
feature3_gains = np.array([1.7], dtype=np.float32)
|
feature3_gains = np.array([1.7], dtype=np.float32)
|
||||||
@ -554,7 +668,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
feature3_thresholds = np.array([3], dtype=np.int32)
|
feature3_thresholds = np.array([3], dtype=np.int32)
|
||||||
feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
|
feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
|
||||||
feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)
|
feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)
|
||||||
feature3_split_types = np.array(['inequality_default_left'])
|
feature3_split_types = np.array([_INEQUALITY_DEFAULT_LEFT])
|
||||||
|
|
||||||
# Grow tree ensemble.
|
# Grow tree ensemble.
|
||||||
grow_op = boosted_trees_ops.update_ensemble_v2(
|
grow_op = boosted_trees_ops.update_ensemble_v2(
|
||||||
@ -629,6 +743,212 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
threshold: 7
|
threshold: 7
|
||||||
left_id: 5
|
left_id: 5
|
||||||
right_id: 6
|
right_id: 6
|
||||||
|
default_direction: DEFAULT_RIGHT
|
||||||
|
}
|
||||||
|
metadata {
|
||||||
|
gain: 2.7
|
||||||
|
original_leaf {
|
||||||
|
scalar: -0.4375
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: 0.114
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: 0.879
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: -0.5875
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: -0.2075
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: 0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tree_weights: 1.0
|
||||||
|
tree_weights: 1.0
|
||||||
|
tree_metadata {
|
||||||
|
is_finalized: true
|
||||||
|
num_layers_grown: 2
|
||||||
|
}
|
||||||
|
tree_metadata {
|
||||||
|
}
|
||||||
|
growing_metadata {
|
||||||
|
num_trees_attempted: 1
|
||||||
|
num_layers_attempted: 2
|
||||||
|
last_layer_node_start: 0
|
||||||
|
last_layer_node_end: 1
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
self.assertEqual(new_stamp, 1)
|
||||||
|
self.assertProtoEquals(expected_result, tree_ensemble)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGrowExistingEnsembleTreeV2NotFinalizedEqualitySplit(self):
|
||||||
|
"""Test growing an existing ensemble with the last tree not finalized."""
|
||||||
|
with self.cached_session() as session:
|
||||||
|
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||||
|
text_format.Merge(
|
||||||
|
"""
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
bucketized_split {
|
||||||
|
feature_id: 4
|
||||||
|
dimension_id: 0
|
||||||
|
left_id: 1
|
||||||
|
right_id: 2
|
||||||
|
}
|
||||||
|
metadata {
|
||||||
|
gain: 7.62
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: 0.714
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: -0.4375
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tree_weights: 1.0
|
||||||
|
tree_metadata {
|
||||||
|
num_layers_grown: 1
|
||||||
|
is_finalized: false
|
||||||
|
}
|
||||||
|
growing_metadata {
|
||||||
|
num_trees_attempted: 1
|
||||||
|
num_layers_attempted: 1
|
||||||
|
}
|
||||||
|
""", tree_ensemble_config)
|
||||||
|
|
||||||
|
# Create existing ensemble with one root split
|
||||||
|
tree_ensemble = boosted_trees_ops.TreeEnsemble(
|
||||||
|
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
|
||||||
|
tree_ensemble_handle = tree_ensemble.resource_handle
|
||||||
|
resources.initialize_resources(resources.shared_resources()).run()
|
||||||
|
|
||||||
|
# Prepare feature inputs.
|
||||||
|
# feature 1 only has a candidate for node 1, feature 2 has candidates
|
||||||
|
# for both nodes and feature 3 only has a candidate for node 2.
|
||||||
|
|
||||||
|
feature_ids = [0, 1, 0]
|
||||||
|
|
||||||
|
feature1_nodes = np.array([1], dtype=np.int32)
|
||||||
|
feature1_gains = np.array([1.4], dtype=np.float32)
|
||||||
|
feature1_dimensions = np.array([0], dtype=np.int32)
|
||||||
|
feature1_thresholds = np.array([21], dtype=np.int32)
|
||||||
|
feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
|
||||||
|
feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
|
||||||
|
feature1_split_types = np.array([_INEQUALITY_DEFAULT_LEFT])
|
||||||
|
|
||||||
|
feature2_nodes = np.array([1, 2], dtype=np.int32)
|
||||||
|
feature2_gains = np.array([0.63, 2.7], dtype=np.float32)
|
||||||
|
feature2_dimensions = np.array([1, 3], dtype=np.int32)
|
||||||
|
feature2_thresholds = np.array([23, 7], dtype=np.int32)
|
||||||
|
feature2_left_node_contribs = np.array([[-0.6], [-1.5]], dtype=np.float32)
|
||||||
|
feature2_right_node_contribs = np.array([[0.24], [2.3]], dtype=np.float32)
|
||||||
|
feature2_split_types = np.array(
|
||||||
|
[_EQUALITY_DEFAULT_RIGHT, _EQUALITY_DEFAULT_RIGHT])
|
||||||
|
|
||||||
|
feature3_nodes = np.array([2], dtype=np.int32)
|
||||||
|
feature3_gains = np.array([1.7], dtype=np.float32)
|
||||||
|
feature3_dimensions = np.array([0], dtype=np.int32)
|
||||||
|
feature3_thresholds = np.array([3], dtype=np.int32)
|
||||||
|
feature3_left_node_contribs = np.array([[-0.75]], dtype=np.float32)
|
||||||
|
feature3_right_node_contribs = np.array([[1.93]], dtype=np.float32)
|
||||||
|
feature3_split_types = np.array([_INEQUALITY_DEFAULT_LEFT])
|
||||||
|
|
||||||
|
# Grow tree ensemble.
|
||||||
|
grow_op = boosted_trees_ops.update_ensemble_v2(
|
||||||
|
tree_ensemble_handle,
|
||||||
|
learning_rate=0.1,
|
||||||
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
|
# tree is going to be finalized now, since we reach depth 2.
|
||||||
|
max_depth=2,
|
||||||
|
feature_ids=feature_ids,
|
||||||
|
dimension_ids=[
|
||||||
|
feature1_dimensions, feature2_dimensions, feature3_dimensions
|
||||||
|
],
|
||||||
|
node_ids=[feature1_nodes, feature2_nodes, feature3_nodes],
|
||||||
|
gains=[feature1_gains, feature2_gains, feature3_gains],
|
||||||
|
thresholds=[
|
||||||
|
feature1_thresholds, feature2_thresholds, feature3_thresholds
|
||||||
|
],
|
||||||
|
left_node_contribs=[
|
||||||
|
feature1_left_node_contribs, feature2_left_node_contribs,
|
||||||
|
feature3_left_node_contribs
|
||||||
|
],
|
||||||
|
right_node_contribs=[
|
||||||
|
feature1_right_node_contribs, feature2_right_node_contribs,
|
||||||
|
feature3_right_node_contribs
|
||||||
|
],
|
||||||
|
split_types=[
|
||||||
|
feature1_split_types, feature2_split_types, feature3_split_types
|
||||||
|
],
|
||||||
|
)
|
||||||
|
session.run(grow_op)
|
||||||
|
|
||||||
|
# Expect the split for node 1 to be chosen from feature 1 and
|
||||||
|
# the split for node 2 to be chosen from feature 2.
|
||||||
|
# The grown tree should be finalized as max tree depth is 2 and we have
|
||||||
|
# grown 2 layers.
|
||||||
|
new_stamp, serialized = session.run(tree_ensemble.serialize())
|
||||||
|
tree_ensemble = boosted_trees_pb2.TreeEnsemble()
|
||||||
|
tree_ensemble.ParseFromString(serialized)
|
||||||
|
|
||||||
|
expected_result = """
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
bucketized_split {
|
||||||
|
feature_id: 4
|
||||||
|
dimension_id: 0
|
||||||
|
left_id: 1
|
||||||
|
right_id: 2
|
||||||
|
}
|
||||||
|
metadata {
|
||||||
|
gain: 7.62
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
bucketized_split {
|
||||||
|
feature_id: 0
|
||||||
|
threshold: 21
|
||||||
|
dimension_id: 0
|
||||||
|
left_id: 3
|
||||||
|
right_id: 4
|
||||||
|
}
|
||||||
|
metadata {
|
||||||
|
gain: 1.4
|
||||||
|
original_leaf {
|
||||||
|
scalar: 0.714
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
categorical_split {
|
||||||
|
feature_id: 1
|
||||||
|
dimension_id: 3
|
||||||
|
value: 7
|
||||||
|
left_id: 5
|
||||||
|
right_id: 6
|
||||||
}
|
}
|
||||||
metadata {
|
metadata {
|
||||||
gain: 2.7
|
gain: 2.7
|
||||||
@ -900,7 +1220,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
feature1_thresholds = np.array([21], dtype=np.int32)
|
feature1_thresholds = np.array([21], dtype=np.int32)
|
||||||
feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
|
feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
|
||||||
feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
|
feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
|
||||||
feature1_split_types = np.array(['inequality_default_right'])
|
feature1_split_types = np.array([_INEQUALITY_DEFAULT_RIGHT])
|
||||||
|
|
||||||
# Grow tree ensemble.
|
# Grow tree ensemble.
|
||||||
grow_op = boosted_trees_ops.update_ensemble_v2(
|
grow_op = boosted_trees_ops.update_ensemble_v2(
|
||||||
@ -955,6 +1275,165 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
threshold: 21
|
threshold: 21
|
||||||
left_id: 1
|
left_id: 1
|
||||||
right_id: 2
|
right_id: 2
|
||||||
|
default_direction: DEFAULT_RIGHT
|
||||||
|
}
|
||||||
|
metadata {
|
||||||
|
gain: -1.4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: -0.6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: 0.165
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tree_weights: 0.15
|
||||||
|
tree_weights: 1.0
|
||||||
|
tree_metadata {
|
||||||
|
num_layers_grown: 1
|
||||||
|
is_finalized: true
|
||||||
|
}
|
||||||
|
tree_metadata {
|
||||||
|
num_layers_grown: 1
|
||||||
|
is_finalized: false
|
||||||
|
}
|
||||||
|
growing_metadata {
|
||||||
|
num_trees_attempted: 2
|
||||||
|
num_layers_attempted: 2
|
||||||
|
last_layer_node_start: 1
|
||||||
|
last_layer_node_end: 3
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
self.assertEqual(new_stamp, 1)
|
||||||
|
self.assertProtoEquals(expected_result, tree_ensemble)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGrowExistingEnsembleTreeV2FinalizedEqualitySplit(self):
|
||||||
|
"""Test growing an existing ensemble with the last tree finalized."""
|
||||||
|
with self.cached_session() as session:
|
||||||
|
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
|
||||||
|
text_format.Merge(
|
||||||
|
"""
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
bucketized_split {
|
||||||
|
feature_id: 4
|
||||||
|
dimension_id: 0
|
||||||
|
left_id: 1
|
||||||
|
right_id: 2
|
||||||
|
}
|
||||||
|
metadata {
|
||||||
|
gain: 7.62
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: 7.14
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: -4.375
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: 0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tree_weights: 0.15
|
||||||
|
tree_weights: 1.0
|
||||||
|
tree_metadata {
|
||||||
|
num_layers_grown: 1
|
||||||
|
is_finalized: true
|
||||||
|
}
|
||||||
|
tree_metadata {
|
||||||
|
}
|
||||||
|
growing_metadata {
|
||||||
|
num_trees_attempted: 1
|
||||||
|
num_layers_attempted: 1
|
||||||
|
}
|
||||||
|
""", tree_ensemble_config)
|
||||||
|
|
||||||
|
# Create existing ensemble with one root split
|
||||||
|
tree_ensemble = boosted_trees_ops.TreeEnsemble(
|
||||||
|
'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
|
||||||
|
tree_ensemble_handle = tree_ensemble.resource_handle
|
||||||
|
resources.initialize_resources(resources.shared_resources()).run()
|
||||||
|
|
||||||
|
# Prepare feature inputs.
|
||||||
|
|
||||||
|
feature_ids = [75]
|
||||||
|
|
||||||
|
feature1_nodes = np.array([0], dtype=np.int32)
|
||||||
|
feature1_gains = np.array([-1.4], dtype=np.float32)
|
||||||
|
feature1_dimensions = np.array([1], dtype=np.int32)
|
||||||
|
feature1_thresholds = np.array([21], dtype=np.int32)
|
||||||
|
feature1_left_node_contribs = np.array([[-6.0]], dtype=np.float32)
|
||||||
|
feature1_right_node_contribs = np.array([[1.65]], dtype=np.float32)
|
||||||
|
feature1_split_types = np.array([_EQUALITY_DEFAULT_RIGHT])
|
||||||
|
|
||||||
|
# Grow tree ensemble.
|
||||||
|
grow_op = boosted_trees_ops.update_ensemble_v2(
|
||||||
|
tree_ensemble_handle,
|
||||||
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
|
learning_rate=0.1,
|
||||||
|
max_depth=2,
|
||||||
|
feature_ids=feature_ids,
|
||||||
|
dimension_ids=[feature1_dimensions],
|
||||||
|
node_ids=[feature1_nodes],
|
||||||
|
gains=[feature1_gains],
|
||||||
|
thresholds=[feature1_thresholds],
|
||||||
|
left_node_contribs=[feature1_left_node_contribs],
|
||||||
|
right_node_contribs=[feature1_right_node_contribs],
|
||||||
|
split_types=[feature1_split_types])
|
||||||
|
session.run(grow_op)
|
||||||
|
|
||||||
|
# Expect a new tree added, with a split on feature 75
|
||||||
|
new_stamp, serialized = session.run(tree_ensemble.serialize())
|
||||||
|
tree_ensemble = boosted_trees_pb2.TreeEnsemble()
|
||||||
|
tree_ensemble.ParseFromString(serialized)
|
||||||
|
|
||||||
|
expected_result = """
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
bucketized_split {
|
||||||
|
feature_id: 4
|
||||||
|
dimension_id: 0
|
||||||
|
left_id: 1
|
||||||
|
right_id: 2
|
||||||
|
}
|
||||||
|
metadata {
|
||||||
|
gain: 7.62
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: 7.14
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nodes {
|
||||||
|
leaf {
|
||||||
|
scalar: -4.375
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
trees {
|
||||||
|
nodes {
|
||||||
|
categorical_split {
|
||||||
|
feature_id: 75
|
||||||
|
dimension_id: 1
|
||||||
|
value: 21
|
||||||
|
left_id: 1
|
||||||
|
right_id: 2
|
||||||
}
|
}
|
||||||
metadata {
|
metadata {
|
||||||
gain: -1.4
|
gain: -1.4
|
||||||
|
Loading…
Reference in New Issue
Block a user