Backward compatible api change:BoostedTreesUpdateEnsembleV2 works on list of feature_ids.
PiperOrigin-RevId: 289687663 Change-Id: I5d12d044ae42fc34f03a3eaa357bf71b7cb06eec
This commit is contained in:
parent
857b55bc2d
commit
8a2a86318b
@ -91,6 +91,14 @@ END
|
|||||||
name: "logits_dimension"
|
name: "logits_dimension"
|
||||||
description: <<END
|
description: <<END
|
||||||
scalar, dimension of the logits
|
scalar, dimension of the logits
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "num_groups"
|
||||||
|
description: <<END
|
||||||
|
Number of groups of split information to process, where a group contains feature
|
||||||
|
ids that are processed together in BoostedTreesCalculateBestFeatureSplitOpV2.
|
||||||
|
INFERRED.
|
||||||
END
|
END
|
||||||
}
|
}
|
||||||
summary: "Updates the tree ensemble by adding a layer to the last tree being grown"
|
summary: "Updates the tree ensemble by adding a layer to the last tree being grown"
|
||||||
|
|||||||
@ -269,9 +269,11 @@ class BoostedTreesUpdateEnsembleV2Op : public OpKernel {
|
|||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->input_list("split_types", &split_types_list));
|
context->input_list("split_types", &split_types_list));
|
||||||
|
|
||||||
const Tensor* feature_ids_t;
|
OpInputList feature_ids_list;
|
||||||
OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t));
|
OP_REQUIRES_OK(context,
|
||||||
const auto feature_ids = feature_ids_t->vec<int32>();
|
context->input_list("feature_ids", &feature_ids_list));
|
||||||
|
// TODO(crawles): Read groups of feature ids and find best splits among all.
|
||||||
|
const auto feature_ids = feature_ids_list[0].vec<int32>();
|
||||||
|
|
||||||
const Tensor* max_depth_t;
|
const Tensor* max_depth_t;
|
||||||
OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
|
OP_REQUIRES_OK(context, context->input("max_depth", &max_depth_t));
|
||||||
|
|||||||
@ -618,7 +618,7 @@ REGISTER_OP("BoostedTreesUpdateEnsemble")
|
|||||||
|
|
||||||
REGISTER_OP("BoostedTreesUpdateEnsembleV2")
|
REGISTER_OP("BoostedTreesUpdateEnsembleV2")
|
||||||
.Input("tree_ensemble_handle: resource")
|
.Input("tree_ensemble_handle: resource")
|
||||||
.Input("feature_ids: int32")
|
.Input("feature_ids: num_groups * int32")
|
||||||
.Input("dimension_ids: num_features * int32")
|
.Input("dimension_ids: num_features * int32")
|
||||||
.Input("node_ids: num_features * int32")
|
.Input("node_ids: num_features * int32")
|
||||||
.Input("gains: num_features * float")
|
.Input("gains: num_features * float")
|
||||||
@ -631,13 +631,18 @@ REGISTER_OP("BoostedTreesUpdateEnsembleV2")
|
|||||||
.Input("pruning_mode: int32")
|
.Input("pruning_mode: int32")
|
||||||
.Attr("num_features: int >= 0") // Inferred.
|
.Attr("num_features: int >= 0") // Inferred.
|
||||||
.Attr("logits_dimension: int = 1")
|
.Attr("logits_dimension: int = 1")
|
||||||
|
.Attr("num_groups: int = 1") // Number of groups to process.
|
||||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
shape_inference::ShapeHandle shape_handle;
|
shape_inference::ShapeHandle shape_handle;
|
||||||
int num_features;
|
int num_features;
|
||||||
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
|
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
|
||||||
|
int num_groups;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("num_groups", &num_groups));
|
||||||
|
|
||||||
// Feature_ids, should be one for each feature.
|
// Feature_ids, should be one for each feature.
|
||||||
shape_inference::ShapeHandle feature_ids_shape;
|
shape_inference::ShapeHandle feature_ids_shape;
|
||||||
|
// TODO(crawles): remove 1 hardcode once kernel operates on multiple
|
||||||
|
// groups.
|
||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &feature_ids_shape));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
|
c->Merge(c->input(1), c->Vector(num_features), &shape_handle));
|
||||||
|
|||||||
@ -180,7 +180,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
# Tree will be finalized now, since we will reach depth 1.
|
# Tree will be finalized now, since we will reach depth 1.
|
||||||
max_depth=1,
|
max_depth=1,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
||||||
node_ids=[feature1_nodes, feature2_nodes],
|
node_ids=[feature1_nodes, feature2_nodes],
|
||||||
gains=[feature1_gains, feature2_gains],
|
gains=[feature1_gains, feature2_gains],
|
||||||
@ -289,7 +289,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
# Tree will be finalized now, since we will reach depth 1.
|
# Tree will be finalized now, since we will reach depth 1.
|
||||||
max_depth=1,
|
max_depth=1,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
||||||
node_ids=[feature1_nodes, feature2_nodes],
|
node_ids=[feature1_nodes, feature2_nodes],
|
||||||
gains=[feature1_gains, feature2_gains],
|
gains=[feature1_gains, feature2_gains],
|
||||||
@ -401,7 +401,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
# Tree will be finalized now, since we will reach depth 1.
|
# Tree will be finalized now, since we will reach depth 1.
|
||||||
max_depth=1,
|
max_depth=1,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
||||||
node_ids=[feature1_nodes, feature2_nodes],
|
node_ids=[feature1_nodes, feature2_nodes],
|
||||||
gains=[feature1_gains, feature2_gains],
|
gains=[feature1_gains, feature2_gains],
|
||||||
@ -809,7 +809,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
# tree is going to be finalized now, since we reach depth 2.
|
# tree is going to be finalized now, since we reach depth 2.
|
||||||
max_depth=2,
|
max_depth=2,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[
|
dimension_ids=[
|
||||||
feature1_dimensions, feature2_dimensions, feature3_dimensions
|
feature1_dimensions, feature2_dimensions, feature3_dimensions
|
||||||
],
|
],
|
||||||
@ -1014,7 +1014,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
# tree is going to be finalized now, since we reach depth 2.
|
# tree is going to be finalized now, since we reach depth 2.
|
||||||
max_depth=2,
|
max_depth=2,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[
|
dimension_ids=[
|
||||||
feature1_dimensions, feature2_dimensions, feature3_dimensions
|
feature1_dimensions, feature2_dimensions, feature3_dimensions
|
||||||
],
|
],
|
||||||
@ -1230,7 +1230,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
# tree is going to be finalized now, since we reach depth 2.
|
# tree is going to be finalized now, since we reach depth 2.
|
||||||
max_depth=2,
|
max_depth=2,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[
|
dimension_ids=[
|
||||||
feature1_dimensions, feature2_dimensions, feature3_dimensions
|
feature1_dimensions, feature2_dimensions, feature3_dimensions
|
||||||
],
|
],
|
||||||
@ -1610,7 +1610,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
learning_rate=0.1,
|
learning_rate=0.1,
|
||||||
max_depth=2,
|
max_depth=2,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions],
|
dimension_ids=[feature1_dimensions],
|
||||||
node_ids=[feature1_nodes],
|
node_ids=[feature1_nodes],
|
||||||
gains=[feature1_gains],
|
gains=[feature1_gains],
|
||||||
@ -1769,7 +1769,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
learning_rate=0.1,
|
learning_rate=0.1,
|
||||||
max_depth=2,
|
max_depth=2,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions],
|
dimension_ids=[feature1_dimensions],
|
||||||
node_ids=[feature1_nodes],
|
node_ids=[feature1_nodes],
|
||||||
gains=[feature1_gains],
|
gains=[feature1_gains],
|
||||||
@ -1942,7 +1942,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING,
|
||||||
learning_rate=0.1,
|
learning_rate=0.1,
|
||||||
max_depth=2,
|
max_depth=2,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions],
|
dimension_ids=[feature1_dimensions],
|
||||||
node_ids=[feature1_nodes],
|
node_ids=[feature1_nodes],
|
||||||
gains=[feature1_gains],
|
gains=[feature1_gains],
|
||||||
@ -2309,7 +2309,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.PRE_PRUNING,
|
||||||
# tree is going to be finalized now, since we reach depth 2.
|
# tree is going to be finalized now, since we reach depth 2.
|
||||||
max_depth=3,
|
max_depth=3,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[
|
dimension_ids=[
|
||||||
feature1_dimensions, feature2_dimensions, feature3_dimensions
|
feature1_dimensions, feature2_dimensions, feature3_dimensions
|
||||||
],
|
],
|
||||||
@ -3041,7 +3041,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
||||||
max_depth=3,
|
max_depth=3,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
||||||
node_ids=[feature1_nodes, feature2_nodes],
|
node_ids=[feature1_nodes, feature2_nodes],
|
||||||
gains=[feature1_gains, feature2_gains],
|
gains=[feature1_gains, feature2_gains],
|
||||||
@ -3140,7 +3140,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
||||||
max_depth=3,
|
max_depth=3,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions],
|
dimension_ids=[feature1_dimensions],
|
||||||
node_ids=[feature1_nodes],
|
node_ids=[feature1_nodes],
|
||||||
gains=[feature1_gains],
|
gains=[feature1_gains],
|
||||||
@ -3293,7 +3293,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
||||||
max_depth=3,
|
max_depth=3,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions],
|
dimension_ids=[feature1_dimensions],
|
||||||
node_ids=[feature1_nodes],
|
node_ids=[feature1_nodes],
|
||||||
gains=[feature1_gains],
|
gains=[feature1_gains],
|
||||||
@ -3679,7 +3679,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
||||||
max_depth=2,
|
max_depth=2,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
||||||
node_ids=[feature1_nodes, feature2_nodes],
|
node_ids=[feature1_nodes, feature2_nodes],
|
||||||
gains=[feature1_gains, feature2_gains],
|
gains=[feature1_gains, feature2_gains],
|
||||||
@ -3778,7 +3778,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
||||||
max_depth=2,
|
max_depth=2,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions],
|
dimension_ids=[feature1_dimensions],
|
||||||
node_ids=[feature1_nodes],
|
node_ids=[feature1_nodes],
|
||||||
gains=[feature1_gains],
|
gains=[feature1_gains],
|
||||||
@ -4014,7 +4014,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
|
|||||||
learning_rate=1.0,
|
learning_rate=1.0,
|
||||||
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
pruning_mode=boosted_trees_ops.PruningMode.POST_PRUNING,
|
||||||
max_depth=1,
|
max_depth=1,
|
||||||
feature_ids=feature_ids,
|
feature_ids=[feature_ids],
|
||||||
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
dimension_ids=[feature1_dimensions, feature2_dimensions],
|
||||||
node_ids=[feature1_nodes, feature2_nodes],
|
node_ids=[feature1_nodes, feature2_nodes],
|
||||||
gains=[feature1_gains, feature2_gains],
|
gains=[feature1_gains, feature2_gains],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user