Bug fix; op output validation for aggregate stats and best feature split
PiperOrigin-RevId: 266181494
This commit is contained in:
parent
b32c5e1e35
commit
c3d844efc8
@ -126,10 +126,18 @@ REGISTER_OP("BoostedTreesCalculateBestFeatureSplit")
|
|||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
|
||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
|
||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
|
||||||
ShapeHandle output_shape = c->MakeShape({c->UnknownDim()});
|
ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
|
||||||
for (int i = 0; i < 7; ++i) {
|
c->set_output(0, rank_1_output_shape);
|
||||||
c->set_output(i, output_shape);
|
c->set_output(1, rank_1_output_shape);
|
||||||
}
|
c->set_output(2, rank_1_output_shape);
|
||||||
|
c->set_output(3, rank_1_output_shape);
|
||||||
|
c->set_output(6, rank_1_output_shape);
|
||||||
|
int logits_dimension;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
|
||||||
|
ShapeHandle contribs_output_shape =
|
||||||
|
c->MakeShape({c->UnknownDim(), logits_dimension});
|
||||||
|
c->set_output(4, contribs_output_shape);
|
||||||
|
c->set_output(5, contribs_output_shape);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -168,10 +176,18 @@ REGISTER_OP("BoostedTreesSparseCalculateBestFeatureSplit")
|
|||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
|
||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape));
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused_shape));
|
||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape));
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused_shape));
|
||||||
ShapeHandle output_shape = c->MakeShape({-1});
|
ShapeHandle rank_1_output_shape = c->MakeShape({c->UnknownDim()});
|
||||||
for (int i = 0; i < 7; ++i) {
|
c->set_output(0, rank_1_output_shape);
|
||||||
c->set_output(i, output_shape);
|
c->set_output(1, rank_1_output_shape);
|
||||||
}
|
c->set_output(2, rank_1_output_shape);
|
||||||
|
c->set_output(3, rank_1_output_shape);
|
||||||
|
c->set_output(6, rank_1_output_shape);
|
||||||
|
int logits_dimension;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("logits_dimension", &logits_dimension));
|
||||||
|
ShapeHandle contribs_output_shape =
|
||||||
|
c->MakeShape({c->UnknownDim(), logits_dimension});
|
||||||
|
c->set_output(4, contribs_output_shape);
|
||||||
|
c->set_output(5, contribs_output_shape);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -301,8 +317,9 @@ REGISTER_OP("BoostedTreesAggregateStats")
|
|||||||
DimensionHandle feature_dim = c->Dim(c->input(3), 1);
|
DimensionHandle feature_dim = c->Dim(c->input(3), 1);
|
||||||
DimensionHandle stats_dim;
|
DimensionHandle stats_dim;
|
||||||
TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim));
|
TF_RETURN_IF_ERROR(c->Add(logits_dim, hessian_dim, &stats_dim));
|
||||||
c->set_output(
|
c->set_output(0, c->MakeShape({max_splits, feature_dim,
|
||||||
0, c->MakeShape({max_splits, feature_dim, num_buckets, stats_dim}));
|
num_buckets + 1, // +1 for missing bucket.
|
||||||
|
stats_dim}));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user