Clearly indicate when to disable optimizations

Many ops mark themselves as stateful even though they primarily want to
disable optimizations such as constant folding and CSE. We add a new
method to clearly indicate this intent even though we are currently not
adding a new flag.

PiperOrigin-RevId: 309253555
Change-Id: I8cae8bbc4c3b71819ee869b1870fce1e39e061be
This commit is contained in:
Gaurav Jain 2020-04-30 10:33:42 -07:00 committed by TensorFlower Gardener
parent 7652750f96
commit cccfd47023
5 changed files with 53 additions and 46 deletions

View File

@ -849,17 +849,17 @@ REGISTER_OP("TensorAsShapeInt64")
REGISTER_OP("NonConstScalarInt32")
.Output("o: int32")
.SetIsStateful() // prevents constant folding
.SetDoNotOptimize()
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("NonConstScalarInt64")
.Output("o: int64")
.SetIsStateful() // prevents constant folding
.SetDoNotOptimize()
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("WithEmptyVectorShape")
.Output("o: int32")
.SetIsStateful() // prevents constant folding
.SetDoNotOptimize()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Vector(0));
return Status::OK();
@ -867,7 +867,7 @@ REGISTER_OP("WithEmptyVectorShape")
REGISTER_OP("WithPartialShape")
.Output("o: int32")
.SetIsStateful() // prevents constant folding
.SetDoNotOptimize()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(
0, c->MakeShape({1, shape_inference::InferenceContext::kUnknownDim, 3,
@ -877,7 +877,7 @@ REGISTER_OP("WithPartialShape")
REGISTER_OP("WithPartialShape2")
.Output("o: int32")
.SetIsStateful() // prevents constant folding
.SetDoNotOptimize()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(
0,
@ -887,7 +887,7 @@ REGISTER_OP("WithPartialShape2")
REGISTER_OP("WithUnknownShape")
.Output("o: int32")
.SetIsStateful() // prevents constant folding
.SetDoNotOptimize()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->UnknownShape());
return Status::OK();

View File

@ -249,6 +249,12 @@ class OpDefBuilderWrapper<true> {
builder_.SetIsStateful();
return *this;
}
OpDefBuilderWrapper<true>& SetDoNotOptimize() {
// We don't have a separate flag to disable optimizations such as constant
// folding and CSE so we reuse the stateful flag.
builder_.SetIsStateful();
return *this;
}
OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
builder_.SetAllowsUninitializedInput();
return *this;
@ -282,6 +288,7 @@ class OpDefBuilderWrapper<false> {
OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; }
OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
OpDefBuilderWrapper<false>& SetDoNotOptimize() { return *this; }
OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }

View File

@ -402,7 +402,7 @@ REGISTER_OP("Empty")
.Output("output: dtype")
.Attr("dtype: type")
.Attr("init: bool = false")
.SetIsStateful()
.SetDoNotOptimize()
.SetShapeFn([](InferenceContext* c) {
ShapeHandle out;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
@ -744,7 +744,7 @@ REGISTER_OP("GuaranteeConst")
return UnchangedShape(c);
})
// We don't want this to be optimized away.
.SetIsStateful();
.SetDoNotOptimize();
// --------------------------------------------------------------------------
REGISTER_OP("ZerosLike")

View File

@ -37,8 +37,8 @@ REGISTER_OP("TensorDataset")
.Output("handle: variant")
.Attr("Toutput_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate that
// `components` have shapes
// compatible with
@ -49,8 +49,8 @@ REGISTER_OP("TensorSliceDataset")
.Output("handle: variant")
.Attr("Toutput_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate that the
// dim-0 slices of `components`
// have shapes compatible with
@ -62,8 +62,8 @@ REGISTER_OP("SparseTensorSliceDataset")
.Input("dense_shape: int64")
.Output("handle: variant")
.Attr("Tvalues: type")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("GeneratorDataset")
@ -79,8 +79,8 @@ REGISTER_OP("GeneratorDataset")
.Attr("Tfinalize_func_args: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ZipDataset")
@ -392,8 +392,8 @@ REGISTER_OP("RangeDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// start, stop, and step should be scalars.
@ -595,8 +595,8 @@ REGISTER_OP("TextLineDataset")
.Input("compression_type: string")
.Input("buffer_size: int64")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `filenames` must be a scalar or a vector.
@ -615,8 +615,8 @@ REGISTER_OP("FixedLengthRecordDataset")
.Input("footer_bytes: int64")
.Input("buffer_size: int64")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `filenames` must be a scalar or a vector.
@ -638,8 +638,8 @@ REGISTER_OP("FixedLengthRecordDatasetV2")
.Input("buffer_size: int64")
.Input("compression_type: string")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `filenames` must be a scalar or a vector.
@ -658,8 +658,8 @@ REGISTER_OP("TFRecordDataset")
.Input("compression_type: string")
.Input("buffer_size: int64")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `filenames` must be a scalar or a vector.

View File

@ -145,8 +145,8 @@ REGISTER_OP("CSVDataset")
.Output("handle: variant")
.Attr("output_types: list({float,double,int32,int64,string}) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `filenames` must be a scalar or a vector.
@ -187,8 +187,8 @@ REGISTER_OP("ExperimentalCSVDataset")
.Output("handle: variant")
.Attr("output_types: list({float,double,int32,int64,string}) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `filenames` must be a scalar or a vector.
@ -426,8 +426,8 @@ REGISTER_OP("LMDBDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalLMDBDataset")
@ -435,8 +435,8 @@ REGISTER_OP("ExperimentalLMDBDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("MapAndBatchDataset")
@ -508,8 +508,8 @@ REGISTER_OP("ExperimentalMapDataset")
REGISTER_OP("MatchingFilesDataset")
.Input("patterns: string")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `patterns` must be a scalar or a vector.
@ -520,8 +520,8 @@ REGISTER_OP("MatchingFilesDataset")
REGISTER_OP("ExperimentalMatchingFilesDataset")
.Input("patterns: string")
.Output("handle: variant")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// `patterns` must be a scalar or a vector.
@ -689,8 +689,8 @@ REGISTER_OP("ExperimentalRandomDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// buffer_size, seed, and seed2 should be scalars.
@ -705,8 +705,8 @@ REGISTER_OP("RandomDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// buffer_size, seed, and seed2 should be scalars.
@ -893,8 +893,8 @@ REGISTER_OP("SqlDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// driver_name, data_source_name, and query should be scalars.
@ -911,8 +911,8 @@ REGISTER_OP("ExperimentalSqlDataset")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
// disable constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// driver_name, data_source_name, and query should be scalars.