From c3d844efc8c9eb85eead00e43714c1599d53ae13 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 29 Aug 2019 11:06:15 -0700 Subject: [PATCH] Bug fix; op output validation for aggregate stats and best feature split PiperOrigin-RevId: 266181494 --- tensorflow/core/ops/boosted_trees_ops.cc | 37 +++++++++++++++++------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc index 39fbd1606cf..d028ceb7e6d 100644 --- a/tensorflow/core/ops/boosted_trees_ops.cc +++ b/tensorflow/core/ops/boosted_trees_ops.cc @@ -126,10 +126,18 @@ REGISTER_OP("BoostedTreesCalculateBestFeatureSplit") TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape)); TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape)); - ShapeHandle output_shape = c->MakeShape({c->UnknownDim()}); - for (int i = 0; i < 7; ++i) { - c->set_output(i, output_shape); - } + ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()}); + c->set_output(0, rank_1_output_shape); + c->set_output(1, rank_1_output_shape); + c->set_output(2, rank_1_output_shape); + c->set_output(3, rank_1_output_shape); + c->set_output(6, rank_1_output_shape); + int logits_dimension; + TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension)); + ShapeHandle contribs_output_shape = + c->MakeShape({c->UnknownDim(), logits_dimension}); + c->set_output(4, contribs_output_shape); + c->set_output(5, contribs_output_shape); return Status::OK(); }); @@ -168,10 +176,18 @@ REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit") TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape)); TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape)); TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape)); - ShapeHandle output_shape = c->MakeShape({-1}); - for (int i = 0; i < 7; ++i) { - c->set_output(i, output_shape); - } + ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()}); + c->set_output(0, rank_1_output_shape); + c->set_output(1, rank_1_output_shape); + c->set_output(2, rank_1_output_shape); + c->set_output(3, rank_1_output_shape); + c->set_output(6, rank_1_output_shape); + int logits_dimension; + TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension)); + ShapeHandle contribs_output_shape = + c->MakeShape({c->UnknownDim(), logits_dimension}); + c->set_output(4, contribs_output_shape); + c->set_output(5, contribs_output_shape); return Status::OK(); }); @@ -301,8 +317,9 @@ REGISTER_OP("BoostedTreesAggregateStats") DimensionHandle feature_dim = c->Dim(c->input(3), 1); DimensionHandle stats_dim; TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim)); - c->set_output( - 0, c->MakeShape({max_splits, feature_dim, num_buckets, stats_dim})); + c->set_output(0, c->MakeShape({max_splits, feature_dim, + num_buckets + 1, // +1 for missing bucket. + stats_dim})); return Status::OK(); });