Bug fix; op output validation for aggregate stats and best feature split

PiperOrigin-RevId: 266181494
This commit is contained in:
A. Unique TensorFlower 2019-08-29 11:06:15 -07:00 committed by TensorFlower Gardener
parent b32c5e1e35
commit c3d844efc8

View File

@ -126,10 +126,18 @@ REGISTER_OP("BoostedTreesCalculateBestFeatureSplit")
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
ShapeHandle output_shape = c->MakeShape({c->UnknownDim()});
for (int i = 0; i < 7; ++i) {
c->set_output(i, output_shape);
}
ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
c->set_output(0, rank_1_output_shape);
c->set_output(1, rank_1_output_shape);
c->set_output(2, rank_1_output_shape);
c->set_output(3, rank_1_output_shape);
c->set_output(6, rank_1_output_shape);
int logits_dimension;
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
ShapeHandle contribs_output_shape =
c->MakeShape({c->UnknownDim(), logits_dimension});
c->set_output(4, contribs_output_shape);
c->set_output(5, contribs_output_shape);
return Status::OK();
});
@ -168,10 +176,18 @@ REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit")
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape));
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape));
ShapeHandle output_shape = c->MakeShape({-1});
for (int i = 0; i < 7; ++i) {
c->set_output(i, output_shape);
}
ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
c->set_output(0, rank_1_output_shape);
c->set_output(1, rank_1_output_shape);
c->set_output(2, rank_1_output_shape);
c->set_output(3, rank_1_output_shape);
c->set_output(6, rank_1_output_shape);
int logits_dimension;
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
ShapeHandle contribs_output_shape =
c->MakeShape({c->UnknownDim(), logits_dimension});
c->set_output(4, contribs_output_shape);
c->set_output(5, contribs_output_shape);
return Status::OK();
});
@ -301,8 +317,9 @@ REGISTER_OP("BoostedTreesAggregateStats")
DimensionHandle feature_dim = c->Dim(c->input(3), 1);
DimensionHandle stats_dim;
TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim));
c->set_output(
0, c->MakeShape({max_splits, feature_dim, num_buckets, stats_dim}));
c->set_output(0, c->MakeShape({max_splits, feature_dim,
num_buckets + 1, // +1 for missing bucket.
stats_dim}));
return Status::OK();
});