Add a pass to drop shape_invariant attribute from TF While/WhileRegion ops.

`shape_invariant` attribute relaxes the constraint on While/WhileRegion ops to have separate operand and result shapes. This is to support changing shapes in each iteration of loop body. For cases when this relaxation is not required (for example when compiling to XLA) this pass can be used to drop the attribute.

PiperOrigin-RevId: 346856163
Change-Id: I91477151be4c28e77d39ff62e715e0822526c3b7
This commit is contained in:
Prakalp Srivastava 2020-12-10 13:43:07 -08:00 committed by TensorFlower Gardener
parent e2427aa68b
commit e7eca6bc4d
4 changed files with 87 additions and 0 deletions

View File

@ -869,6 +869,7 @@ cc_library(
"transforms/contraction_fusion.cc",
"transforms/decompose_resource_ops_pass.cc",
"transforms/device_index_selector.cc",
"transforms/drop_while_shape_invariant.cc",
"transforms/einsum.cc",
"transforms/executor_island_coarsening.cc",
"transforms/executor_tpuv1_inline_tpu_island.cc",

View File

@ -0,0 +1,29 @@
// RUN: tf-opt %s -tf-drop-while-shape-invariant | FileCheck %s
// CHECK-LABEL: while_shape_invariant
// CHECK-NOT: shape_invariant
func @while_shape_invariant(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "tf.While"(%arg0) {cond = @while_cond, body = @while_body, is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
%1 = "tf.WhileRegion"(%arg0) ( {
^cond(%carg0: tensor<*xf32>):
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
"tf.Yield"(%2) : (tensor<i1>) -> ()
}, {
^body(%barg0: tensor<*xf32>):
%2 = "tf.SomeOp"(%barg0) : (tensor<*xf32>) -> tensor<*xf32>
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
}) {is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
return %0, %1 : tensor<*xf32>, tensor<*xf32>
}
func @while_cond(%arg0: tensor<*xf32>) -> tensor<i1> {
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
return %0 : tensor<i1>
}
func @while_body(%arg0: tensor<*xf32>) -> (tensor<*xf32>) {
%0 = "tf.SomeOp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

View File

@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
namespace {
constexpr char kShapeInvariantAttr[] = "shape_invariant";
// Drop `shape_invariant` attribute from tf.While and tf.WhileRegion op. This
// would allow shape inference pass to further refine operand/result shapes of
// these ops. This is only safe to do when compiling to XLA.
class DropWhileShapeInvariantPass
: public PassWrapper<DropWhileShapeInvariantPass, FunctionPass> {
void runOnFunction() override;
};
void DropWhileShapeInvariantPass::runOnFunction() {
getFunction().walk([](Operation* op) {
if (llvm::isa<WhileOp, WhileRegionOp>(op))
op->removeAttr(kShapeInvariantAttr);
});
}
static PassRegistration<DropWhileShapeInvariantPass> pass(
"tf-drop-while-shape-invariant",
"Drop `shape_invariant` attrbute from While/WhileRegion ops.");
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateDropWhileShapeInvariantPass() {
return std::make_unique<DropWhileShapeInvariantPass>();
}
} // namespace TF
} // namespace mlir

View File

@ -39,6 +39,10 @@ std::unique_ptr<OperationPass<FuncOp>>
CreateExecutorDialectToFunctionalConversionPass();
namespace TF {
// Creates a pass that drops `shape_invariant` attribute from While/WhileRegion
// ops.
std::unique_ptr<OperationPass<FuncOp>> CreateDropWhileShapeInvariantPass();
// Transforms functional control flow operations in the TensorFlow dialect to
// MLIR Control Flow Graph (CFG) form.
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();