Add a pass to drop shape_invariant
attribute from TF While/WhileRegion ops.
`shape_invariant` attribute relaxes the constraint on While/WhileRegion ops to have separate operand and result shapes. This is to support changing shapes in each iteration of loop body. For cases when this relaxation is not required (for example when compiling to XLA) this pass can be used to drop the attribute. PiperOrigin-RevId: 346856163 Change-Id: I91477151be4c28e77d39ff62e715e0822526c3b7
This commit is contained in:
parent
e2427aa68b
commit
e7eca6bc4d
@ -869,6 +869,7 @@ cc_library(
|
||||
"transforms/contraction_fusion.cc",
|
||||
"transforms/decompose_resource_ops_pass.cc",
|
||||
"transforms/device_index_selector.cc",
|
||||
"transforms/drop_while_shape_invariant.cc",
|
||||
"transforms/einsum.cc",
|
||||
"transforms/executor_island_coarsening.cc",
|
||||
"transforms/executor_tpuv1_inline_tpu_island.cc",
|
||||
|
@ -0,0 +1,29 @@
|
||||
// RUN: tf-opt %s -tf-drop-while-shape-invariant | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: while_shape_invariant
|
||||
// CHECK-NOT: shape_invariant
|
||||
func @while_shape_invariant(%arg0: tensor<4xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%0 = "tf.While"(%arg0) {cond = @while_cond, body = @while_body, is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
|
||||
|
||||
%1 = "tf.WhileRegion"(%arg0) ( {
|
||||
^cond(%carg0: tensor<*xf32>):
|
||||
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
"tf.Yield"(%2) : (tensor<i1>) -> ()
|
||||
}, {
|
||||
^body(%barg0: tensor<*xf32>):
|
||||
%2 = "tf.SomeOp"(%barg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
"tf.Yield"(%2) : (tensor<*xf32>) -> ()
|
||||
}) {is_stateless = false, shape_invariant} : (tensor<4xf32>) -> (tensor<*xf32>)
|
||||
|
||||
return %0, %1 : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
func @while_cond(%arg0: tensor<*xf32>) -> tensor<i1> {
|
||||
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
func @while_body(%arg0: tensor<*xf32>) -> (tensor<*xf32>) {
|
||||
%0 = "tf.SomeOp"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
@ -0,0 +1,53 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TF {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kShapeInvariantAttr[] = "shape_invariant";
|
||||
|
||||
// Drop `shape_invariant` attribute from tf.While and tf.WhileRegion op. This
|
||||
// would allow shape inference pass to further refine operand/result shapes of
|
||||
// these ops. This is only safe to do when compiling to XLA.
|
||||
class DropWhileShapeInvariantPass
|
||||
: public PassWrapper<DropWhileShapeInvariantPass, FunctionPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
void DropWhileShapeInvariantPass::runOnFunction() {
|
||||
getFunction().walk([](Operation* op) {
|
||||
if (llvm::isa<WhileOp, WhileRegionOp>(op))
|
||||
op->removeAttr(kShapeInvariantAttr);
|
||||
});
|
||||
}
|
||||
|
||||
static PassRegistration<DropWhileShapeInvariantPass> pass(
|
||||
"tf-drop-while-shape-invariant",
|
||||
"Drop `shape_invariant` attrbute from While/WhileRegion ops.");
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDropWhileShapeInvariantPass() {
|
||||
return std::make_unique<DropWhileShapeInvariantPass>();
|
||||
}
|
||||
|
||||
} // namespace TF
|
||||
} // namespace mlir
|
@ -39,6 +39,10 @@ std::unique_ptr<OperationPass<FuncOp>>
|
||||
CreateExecutorDialectToFunctionalConversionPass();
|
||||
|
||||
namespace TF {
|
||||
// Creates a pass that drops `shape_invariant` attribute from While/WhileRegion
|
||||
// ops.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDropWhileShapeInvariantPass();
|
||||
|
||||
// Transforms functional control flow operations in the TensorFlow dialect to
|
||||
// MLIR Control Flow Graph (CFG) form.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateTFFunctionalControlFlowToCFG();
|
||||
|
Loading…
Reference in New Issue
Block a user