Support empty ensemble with vector leafs.

PiperOrigin-RevId: 268472135
This commit is contained in:
A. Unique TensorFlower 2019-09-11 09:12:13 -07:00 committed by TensorFlower Gardener
parent c767f107c2
commit 680b6d2576
5 changed files with 95 additions and 15 deletions

View File

@ -253,15 +253,24 @@ void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
}
// Add a tree to the ensemble and returns a new tree_id.
int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) {
return AddNewTreeWithLogits(weight, 0.0);
int32 BoostedTreesEnsembleResource::AddNewTree(const float weight,
const int32 logits_dimension) {
const std::vector<float> empty_leaf(logits_dimension);
return AddNewTreeWithLogits(weight, empty_leaf, logits_dimension);
}
int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight,
const float logits) {
int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(
const float weight, const std::vector<float>& logits,
const int32 logits_dimension) {
const int32 new_tree_id = tree_ensemble_->trees_size();
auto* node = tree_ensemble_->add_trees()->add_nodes();
node->mutable_leaf()->set_scalar(logits);
if (logits_dimension == 1) {
node->mutable_leaf()->set_scalar(logits[0]);
} else {
for (int32 i = 0; i < logits_dimension; ++i) {
node->mutable_leaf()->mutable_vector()->add_value(logits[i]);
}
}
tree_ensemble_->add_tree_weights(weight);
tree_ensemble_->add_tree_metadata();

View File

@ -102,10 +102,12 @@ class BoostedTreesEnsembleResource : public StampedResource {
int32 right_id(const int32 tree_id, const int32 node_id) const;
// Add a tree to the ensemble and returns a new tree_id.
int32 AddNewTree(const float weight);
int32 AddNewTree(const float weight, const int32 logits_dimension);
// Adds new tree with one node to the ensemble and sets node's value to logits
int32 AddNewTreeWithLogits(const float weight, const float logits);
int32 AddNewTreeWithLogits(const float weight,
const std::vector<float>& logits,
const int32 logits_dimension);
// Grows the tree by adding a bucketized split and leaves.
void AddBucketizedSplitNode(

View File

@ -136,7 +136,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
}
if (ensemble_resource->num_trees() > 0) {
// Create a dummy new tree with an empty node.
ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
ensemble_resource->AddNewTree(kLayerByLayerTreeWeight, 1);
}
}
// If we managed to split, update the node range. If we didn't, don't
@ -159,7 +159,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
// boosting.
if (num_trees <= 0) {
// Create a new tree with a no-op leaf.
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight);
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight, 1);
}
return current_tree;
}
@ -357,7 +357,7 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
}
if (ensemble_resource->num_trees() > 0) {
// Create a dummy new tree with an empty node.
ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
ensemble_resource->AddNewTree(kLayerByLayerTreeWeight, logits_dim_);
}
}
// If we managed to split, update the node range. If we didn't, don't
@ -380,7 +380,7 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
// boosting.
if (num_trees <= 0) {
// Create a new tree with a no-op leaf.
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight);
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight, logits_dim_);
}
return current_tree;
}
@ -504,7 +504,8 @@ class BoostedTreesCenterBiasOp : public OpKernel {
float current_bias = 0.0;
bool continue_centering = true;
if (ensemble_resource->num_trees() == 0) {
ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, logits);
ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, {logits},
1);
current_bias = logits;
} else {
const auto& current_biases = ensemble_resource->node_value(0, 0);

View File

@ -2459,9 +2459,20 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
}
}
}
trees {
nodes {
leaf {
vector {
value: 0
value: 0
}
}
}
}
tree_weights: 0.1
tree_weights: 0.2
tree_weights: 1.0
tree_weights: 1.0
""", tree_ensemble_config)
# Create existing ensemble with one root split

View File

@ -438,6 +438,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
}
metadata {
gain: 7.65
original_leaf {
vector {
value: 0.0
value: 0.0
}
}
}
}
nodes {
@ -464,7 +470,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
trees {
nodes {
leaf {
scalar: 0.0
vector {
value: 0.0
value: 0.0
}
}
}
}
@ -1103,7 +1112,6 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
trees {
nodes {
leaf {
scalar: 0.0
}
}
}
@ -1352,7 +1360,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
trees {
nodes {
leaf {
scalar: 0.0
vector {
value: 0.0
value: 0.0
}
}
}
}
@ -3065,6 +3076,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
}
metadata {
gain: -0.2
original_leaf {
vector {
value: 0.0
value: 0.0
}
}
}
}
nodes {
@ -3152,6 +3169,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
}
metadata {
gain: -0.2
original_leaf {
vector {
value: 0.0
value: 0.0
}
}
}
}
nodes {
@ -3301,6 +3324,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
}
metadata {
gain: -0.2
original_leaf {
vector {
value: 0.0
value: 0.0
}
}
}
}
nodes {
@ -3357,6 +3386,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
trees {
nodes {
leaf {
vector {
value: 0
value: 0
}
}
}
}
@ -3682,6 +3715,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
}
metadata {
gain: -0.62
original_leaf {
vector {
value: 0.0
value: 0.0
}
}
}
}
nodes {
@ -3766,12 +3805,20 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
trees {
nodes {
leaf {
vector {
value: 0
value: 0
}
}
}
}
trees {
nodes {
leaf {
vector {
value: 0
value: 0
}
}
}
}
@ -4000,6 +4047,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
}
metadata {
gain: 7.62
original_leaf {
vector {
value: 0.0
value: 0.0
}
}
}
}
nodes {
@ -4026,6 +4079,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
trees {
nodes {
leaf {
vector {
value: 0
value: 0
}
}
}
}