Bug fix; op output validation for aggregate stats and best feature split
PiperOrigin-RevId: 266181494
This commit is contained in:
parent
b32c5e1e35
commit
c3d844efc8
@ -126,10 +126,18 @@ REGISTER_OP("BoostedTreesCalculateBestFeatureSplit")
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
|
||||
ShapeHandle output_shape = c->MakeShape({c->UnknownDim()});
|
||||
for (int i = 0; i < 7; ++i) {
|
||||
c->set_output(i, output_shape);
|
||||
}
|
||||
ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
|
||||
c->set_output(0, rank_1_output_shape);
|
||||
c->set_output(1, rank_1_output_shape);
|
||||
c->set_output(2, rank_1_output_shape);
|
||||
c->set_output(3, rank_1_output_shape);
|
||||
c->set_output(6, rank_1_output_shape);
|
||||
int logits_dimension;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
|
||||
ShapeHandle contribs_output_shape =
|
||||
c->MakeShape({c->UnknownDim(), logits_dimension});
|
||||
c->set_output(4, contribs_output_shape);
|
||||
c->set_output(5, contribs_output_shape);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
@ -168,10 +176,18 @@ REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit")
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape));
|
||||
ShapeHandle output_shape = c->MakeShape({-1});
|
||||
for (int i = 0; i < 7; ++i) {
|
||||
c->set_output(i, output_shape);
|
||||
}
|
||||
ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
|
||||
c->set_output(0, rank_1_output_shape);
|
||||
c->set_output(1, rank_1_output_shape);
|
||||
c->set_output(2, rank_1_output_shape);
|
||||
c->set_output(3, rank_1_output_shape);
|
||||
c->set_output(6, rank_1_output_shape);
|
||||
int logits_dimension;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
|
||||
ShapeHandle contribs_output_shape =
|
||||
c->MakeShape({c->UnknownDim(), logits_dimension});
|
||||
c->set_output(4, contribs_output_shape);
|
||||
c->set_output(5, contribs_output_shape);
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
@ -301,8 +317,9 @@ REGISTER_OP("BoostedTreesAggregateStats")
|
||||
DimensionHandle feature_dim = c->Dim(c->input(3), 1);
|
||||
DimensionHandle stats_dim;
|
||||
TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim));
|
||||
c->set_output(
|
||||
0, c->MakeShape({max_splits, feature_dim, num_buckets, stats_dim}));
|
||||
c->set_output(0, c->MakeShape({max_splits, feature_dim,
|
||||
num_buckets + 1, // +1 for missing bucket.
|
||||
stats_dim}));
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user