Support empty ensemble with vector leafs.
PiperOrigin-RevId: 268472135
This commit is contained in:
parent
c767f107c2
commit
680b6d2576
@ -253,15 +253,24 @@ void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
|
||||
}
|
||||
|
||||
// Add a tree to the ensemble and returns a new tree_id.
|
||||
int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) {
|
||||
return AddNewTreeWithLogits(weight, 0.0);
|
||||
int32 BoostedTreesEnsembleResource::AddNewTree(const float weight,
|
||||
const int32 logits_dimension) {
|
||||
const std::vector<float> empty_leaf(logits_dimension);
|
||||
return AddNewTreeWithLogits(weight, empty_leaf, logits_dimension);
|
||||
}
|
||||
|
||||
int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight,
|
||||
const float logits) {
|
||||
int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(
|
||||
const float weight, const std::vector<float>& logits,
|
||||
const int32 logits_dimension) {
|
||||
const int32 new_tree_id = tree_ensemble_->trees_size();
|
||||
auto* node = tree_ensemble_->add_trees()->add_nodes();
|
||||
node->mutable_leaf()->set_scalar(logits);
|
||||
if (logits_dimension == 1) {
|
||||
node->mutable_leaf()->set_scalar(logits[0]);
|
||||
} else {
|
||||
for (int32 i = 0; i < logits_dimension; ++i) {
|
||||
node->mutable_leaf()->mutable_vector()->add_value(logits[i]);
|
||||
}
|
||||
}
|
||||
tree_ensemble_->add_tree_weights(weight);
|
||||
tree_ensemble_->add_tree_metadata();
|
||||
|
||||
|
@ -102,10 +102,12 @@ class BoostedTreesEnsembleResource : public StampedResource {
|
||||
int32 right_id(const int32 tree_id, const int32 node_id) const;
|
||||
|
||||
// Add a tree to the ensemble and returns a new tree_id.
|
||||
int32 AddNewTree(const float weight);
|
||||
int32 AddNewTree(const float weight, const int32 logits_dimension);
|
||||
|
||||
// Adds new tree with one node to the ensemble and sets node's value to logits
|
||||
int32 AddNewTreeWithLogits(const float weight, const float logits);
|
||||
int32 AddNewTreeWithLogits(const float weight,
|
||||
const std::vector<float>& logits,
|
||||
const int32 logits_dimension);
|
||||
|
||||
// Grows the tree by adding a bucketized split and leaves.
|
||||
void AddBucketizedSplitNode(
|
||||
|
@ -136,7 +136,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
||||
}
|
||||
if (ensemble_resource->num_trees() > 0) {
|
||||
// Create a dummy new tree with an empty node.
|
||||
ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
|
||||
ensemble_resource->AddNewTree(kLayerByLayerTreeWeight, 1);
|
||||
}
|
||||
}
|
||||
// If we managed to split, update the node range. If we didn't, don't
|
||||
@ -159,7 +159,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
|
||||
// boosting.
|
||||
if (num_trees <= 0) {
|
||||
// Create a new tree with a no-op leaf.
|
||||
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight);
|
||||
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight, 1);
|
||||
}
|
||||
return current_tree;
|
||||
}
|
||||
@ -357,7 +357,7 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
||||
}
|
||||
if (ensemble_resource->num_trees() > 0) {
|
||||
// Create a dummy new tree with an empty node.
|
||||
ensemble_resource->AddNewTree(kLayerByLayerTreeWeight);
|
||||
ensemble_resource->AddNewTree(kLayerByLayerTreeWeight, logits_dim_);
|
||||
}
|
||||
}
|
||||
// If we managed to split, update the node range. If we didn't, don't
|
||||
@ -380,7 +380,7 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
||||
// boosting.
|
||||
if (num_trees <= 0) {
|
||||
// Create a new tree with a no-op leaf.
|
||||
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight);
|
||||
current_tree = resource->AddNewTree(kLayerByLayerTreeWeight, logits_dim_);
|
||||
}
|
||||
return current_tree;
|
||||
}
|
||||
@ -504,7 +504,8 @@ class BoostedTreesCenterBiasOp : public OpKernel {
|
||||
float current_bias = 0.0;
|
||||
bool continue_centering = true;
|
||||
if (ensemble_resource->num_trees() == 0) {
|
||||
ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, logits);
|
||||
ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, {logits},
|
||||
1);
|
||||
current_bias = logits;
|
||||
} else {
|
||||
const auto& current_biases = ensemble_resource->node_value(0, 0);
|
||||
|
@ -2459,9 +2459,20 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
}
|
||||
}
|
||||
trees {
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: 0
|
||||
value: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tree_weights: 0.1
|
||||
tree_weights: 0.2
|
||||
tree_weights: 1.0
|
||||
tree_weights: 1.0
|
||||
""", tree_ensemble_config)
|
||||
|
||||
# Create existing ensemble with one root split
|
||||
|
@ -438,6 +438,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
metadata {
|
||||
gain: 7.65
|
||||
original_leaf {
|
||||
vector {
|
||||
value: 0.0
|
||||
value: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
@ -464,7 +470,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
trees {
|
||||
nodes {
|
||||
leaf {
|
||||
scalar: 0.0
|
||||
vector {
|
||||
value: 0.0
|
||||
value: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1103,7 +1112,6 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
trees {
|
||||
nodes {
|
||||
leaf {
|
||||
scalar: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1352,7 +1360,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
trees {
|
||||
nodes {
|
||||
leaf {
|
||||
scalar: 0.0
|
||||
vector {
|
||||
value: 0.0
|
||||
value: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3065,6 +3076,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
metadata {
|
||||
gain: -0.2
|
||||
original_leaf {
|
||||
vector {
|
||||
value: 0.0
|
||||
value: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
@ -3152,6 +3169,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
metadata {
|
||||
gain: -0.2
|
||||
original_leaf {
|
||||
vector {
|
||||
value: 0.0
|
||||
value: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
@ -3301,6 +3324,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
metadata {
|
||||
gain: -0.2
|
||||
original_leaf {
|
||||
vector {
|
||||
value: 0.0
|
||||
value: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
@ -3357,6 +3386,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
trees {
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: 0
|
||||
value: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3682,6 +3715,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
metadata {
|
||||
gain: -0.62
|
||||
original_leaf {
|
||||
vector {
|
||||
value: 0.0
|
||||
value: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
@ -3766,12 +3805,20 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
trees {
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: 0
|
||||
value: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
trees {
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: 0
|
||||
value: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4000,6 +4047,12 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
}
|
||||
metadata {
|
||||
gain: 7.62
|
||||
original_leaf {
|
||||
vector {
|
||||
value: 0.0
|
||||
value: 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
nodes {
|
||||
@ -4026,6 +4079,10 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
||||
trees {
|
||||
nodes {
|
||||
leaf {
|
||||
vector {
|
||||
value: 0
|
||||
value: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user