Clearly indicate when to disable optimizations
Many ops mark themselves as stateful even though they primarily want to disable optimizations such as constant folding and CSE. We add a new method to clearly indicate this intent even though we are currently not adding a new flag. PiperOrigin-RevId: 309253555 Change-Id: I8cae8bbc4c3b71819ee869b1870fce1e39e061be
This commit is contained in:
parent
7652750f96
commit
cccfd47023
tensorflow/core
common_runtime
framework
ops
@ -849,17 +849,17 @@ REGISTER_OP("TensorAsShapeInt64")
|
||||
|
||||
REGISTER_OP("NonConstScalarInt32")
|
||||
.Output("o: int32")
|
||||
.SetIsStateful() // prevents constant folding
|
||||
.SetDoNotOptimize()
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("NonConstScalarInt64")
|
||||
.Output("o: int64")
|
||||
.SetIsStateful() // prevents constant folding
|
||||
.SetDoNotOptimize()
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("WithEmptyVectorShape")
|
||||
.Output("o: int32")
|
||||
.SetIsStateful() // prevents constant folding
|
||||
.SetDoNotOptimize()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->Vector(0));
|
||||
return Status::OK();
|
||||
@ -867,7 +867,7 @@ REGISTER_OP("WithEmptyVectorShape")
|
||||
|
||||
REGISTER_OP("WithPartialShape")
|
||||
.Output("o: int32")
|
||||
.SetIsStateful() // prevents constant folding
|
||||
.SetDoNotOptimize()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(
|
||||
0, c->MakeShape({1, shape_inference::InferenceContext::kUnknownDim, 3,
|
||||
@ -877,7 +877,7 @@ REGISTER_OP("WithPartialShape")
|
||||
|
||||
REGISTER_OP("WithPartialShape2")
|
||||
.Output("o: int32")
|
||||
.SetIsStateful() // prevents constant folding
|
||||
.SetDoNotOptimize()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(
|
||||
0,
|
||||
@ -887,7 +887,7 @@ REGISTER_OP("WithPartialShape2")
|
||||
|
||||
REGISTER_OP("WithUnknownShape")
|
||||
.Output("o: int32")
|
||||
.SetIsStateful() // prevents constant folding
|
||||
.SetDoNotOptimize()
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
c->set_output(0, c->UnknownShape());
|
||||
return Status::OK();
|
||||
|
@ -249,6 +249,12 @@ class OpDefBuilderWrapper<true> {
|
||||
builder_.SetIsStateful();
|
||||
return *this;
|
||||
}
|
||||
OpDefBuilderWrapper<true>& SetDoNotOptimize() {
|
||||
// We don't have a separate flag to disable optimizations such as constant
|
||||
// folding and CSE so we reuse the stateful flag.
|
||||
builder_.SetIsStateful();
|
||||
return *this;
|
||||
}
|
||||
OpDefBuilderWrapper<true>& SetAllowsUninitializedInput() {
|
||||
builder_.SetAllowsUninitializedInput();
|
||||
return *this;
|
||||
@ -282,6 +288,7 @@ class OpDefBuilderWrapper<false> {
|
||||
OpDefBuilderWrapper<false>& SetIsCommutative() { return *this; }
|
||||
OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
|
||||
OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
|
||||
OpDefBuilderWrapper<false>& SetDoNotOptimize() { return *this; }
|
||||
OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
|
||||
OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
|
||||
OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
|
||||
|
@ -402,7 +402,7 @@ REGISTER_OP("Empty")
|
||||
.Output("output: dtype")
|
||||
.Attr("dtype: type")
|
||||
.Attr("init: bool = false")
|
||||
.SetIsStateful()
|
||||
.SetDoNotOptimize()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle out;
|
||||
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
|
||||
@ -744,7 +744,7 @@ REGISTER_OP("GuaranteeConst")
|
||||
return UnchangedShape(c);
|
||||
})
|
||||
// We don't want this to be optimized away.
|
||||
.SetIsStateful();
|
||||
.SetDoNotOptimize();
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("ZerosLike")
|
||||
|
@ -37,8 +37,8 @@ REGISTER_OP("TensorDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("Toutput_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate that
|
||||
// `components` have shapes
|
||||
// compatible with
|
||||
@ -49,8 +49,8 @@ REGISTER_OP("TensorSliceDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("Toutput_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate that the
|
||||
// dim-0 slices of `components`
|
||||
// have shapes compatible with
|
||||
@ -62,8 +62,8 @@ REGISTER_OP("SparseTensorSliceDataset")
|
||||
.Input("dense_shape: int64")
|
||||
.Output("handle: variant")
|
||||
.Attr("Tvalues: type")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("GeneratorDataset")
|
||||
@ -79,8 +79,8 @@ REGISTER_OP("GeneratorDataset")
|
||||
.Attr("Tfinalize_func_args: list(type) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ZipDataset")
|
||||
@ -392,8 +392,8 @@ REGISTER_OP("RangeDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// start, stop, and step should be scalars.
|
||||
@ -595,8 +595,8 @@ REGISTER_OP("TextLineDataset")
|
||||
.Input("compression_type: string")
|
||||
.Input("buffer_size: int64")
|
||||
.Output("handle: variant")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// `filenames` must be a scalar or a vector.
|
||||
@ -615,8 +615,8 @@ REGISTER_OP("FixedLengthRecordDataset")
|
||||
.Input("footer_bytes: int64")
|
||||
.Input("buffer_size: int64")
|
||||
.Output("handle: variant")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// `filenames` must be a scalar or a vector.
|
||||
@ -638,8 +638,8 @@ REGISTER_OP("FixedLengthRecordDatasetV2")
|
||||
.Input("buffer_size: int64")
|
||||
.Input("compression_type: string")
|
||||
.Output("handle: variant")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// `filenames` must be a scalar or a vector.
|
||||
@ -658,8 +658,8 @@ REGISTER_OP("TFRecordDataset")
|
||||
.Input("compression_type: string")
|
||||
.Input("buffer_size: int64")
|
||||
.Output("handle: variant")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// `filenames` must be a scalar or a vector.
|
||||
|
@ -145,8 +145,8 @@ REGISTER_OP("CSVDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list({float,double,int32,int64,string}) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// `filenames` must be a scalar or a vector.
|
||||
@ -187,8 +187,8 @@ REGISTER_OP("ExperimentalCSVDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list({float,double,int32,int64,string}) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// `filenames` must be a scalar or a vector.
|
||||
@ -426,8 +426,8 @@ REGISTER_OP("LMDBDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("ExperimentalLMDBDataset")
|
||||
@ -435,8 +435,8 @@ REGISTER_OP("ExperimentalLMDBDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("MapAndBatchDataset")
|
||||
@ -508,8 +508,8 @@ REGISTER_OP("ExperimentalMapDataset")
|
||||
REGISTER_OP("MatchingFilesDataset")
|
||||
.Input("patterns: string")
|
||||
.Output("handle: variant")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// `patterns` must be a scalar or a vector.
|
||||
@ -520,8 +520,8 @@ REGISTER_OP("MatchingFilesDataset")
|
||||
REGISTER_OP("ExperimentalMatchingFilesDataset")
|
||||
.Input("patterns: string")
|
||||
.Output("handle: variant")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// `patterns` must be a scalar or a vector.
|
||||
@ -689,8 +689,8 @@ REGISTER_OP("ExperimentalRandomDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// buffer_size, seed, and seed2 should be scalars.
|
||||
@ -705,8 +705,8 @@ REGISTER_OP("RandomDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// buffer_size, seed, and seed2 should be scalars.
|
||||
@ -893,8 +893,8 @@ REGISTER_OP("SqlDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// driver_name, data_source_name, and query should be scalars.
|
||||
@ -911,8 +911,8 @@ REGISTER_OP("ExperimentalSqlDataset")
|
||||
.Output("handle: variant")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
|
||||
// stateful to inhibit constant folding.
|
||||
.SetDoNotOptimize() // TODO(b/123753214): Source dataset ops must
|
||||
// disable constant folding.
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
shape_inference::ShapeHandle unused;
|
||||
// driver_name, data_source_name, and query should be scalars.
|
||||
|
Loading…
Reference in New Issue
Block a user