[TF:MLIR] Add canonicalization pattern to TransposeOp and compose a layout optimizer pipeline

PiperOrigin-RevId: 296081205
Change-Id: Ica9b311ba83e2e75b726eacbdc393c03692dacb8
This commit is contained in:
Eugene Zhulenev 2020-02-19 16:12:20 -08:00 committed by TensorFlower Gardener
parent 2b95bfb6d8
commit edaaeaddbd
4 changed files with 146 additions and 10 deletions

View File

@@ -151,6 +151,26 @@ static bool AreCastCompatible(Type a, Type b) {
b_kind == TensorFlowTypes::VARIANT; b_kind == TensorFlowTypes::VARIANT;
} }
// Returns true if the two permutations cancel each other, i.e. `perm1` is the
// inverse of `perm0` (applying one transpose after the other restores the
// original layout). Empty or size-mismatched permutations never cancel.
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
                                       DenseIntElementsAttr perm1) {
  if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
  if (perm0.getNumElements() != perm1.getNumElements()) return false;

  SmallVector<int64_t, 8> perm0_values;
  for (auto value : perm0.getIntValues())
    perm0_values.push_back(value.getSExtValue());

  SmallVector<int64_t, 8> perm1_values;
  for (auto value : perm1.getIntValues())
    perm1_values.push_back(value.getSExtValue());

  const int64_t num_elements = perm0_values.size();
  for (int64_t i = 0; i < num_elements; ++i) {
    // Guard against malformed permutation attributes: an out-of-range index
    // would otherwise cause an out-of-bounds read into `perm0_values`.
    const int64_t inverse_index = perm1_values[i];
    if (inverse_index < 0 || inverse_index >= num_elements) return false;
    if (perm0_values[inverse_index] != i) return false;
  }

  return true;
}
static bool IsUnknownDimOrRank(int64_t dim_or_rank) { static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
return dim_or_rank == -1; return dim_or_rank == -1;
} }
@@ -2723,23 +2743,46 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x,
perm); perm);
} }
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) { namespace {
auto const_perm = dyn_cast_or_null<TF::ConstOp>(perm().getDefiningOp());
if (!const_perm) { OpFoldResult FoldIdentityTranspose(TransposeOp op) {
return {}; auto const_perm = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
} if (!const_perm) return {};
auto const_value = const_perm.value(); auto const_value = const_perm.value();
const auto &elements = const_value.getValues<APInt>(); const auto &elements = const_value.getValues<APInt>();
for (auto it : llvm::enumerate(elements)) { for (auto it : llvm::enumerate(elements)) {
if (it.index() != it.value()) { if (it.index() != it.value()) return {};
return {};
}
} }
return x(); return op.x();
}
// Folds a pair of back-to-back transposes whose permutations are mutual
// inverses: transpose(transpose(x, p0), p1) == x when p1 cancels p0.
OpFoldResult FoldCancellableTranspose(TransposeOp op) {
  // The folded op's input must itself be produced by a tf.Transpose.
  auto inner = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
  if (!inner) return {};

  // Both permutation operands must be constants so we can compare them.
  auto outer_perm_op = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
  auto inner_perm_op =
      dyn_cast_or_null<TF::ConstOp>(inner.perm().getDefiningOp());
  if (!outer_perm_op || !inner_perm_op) return {};

  // Fold only when the permutation indices cancel each other.
  auto outer_perm = outer_perm_op.value().cast<DenseIntElementsAttr>();
  auto inner_perm = inner_perm_op.value().cast<DenseIntElementsAttr>();
  if (AreCancellablePermutations(outer_perm, inner_perm)) return inner.x();

  return {};
}
} // namespace
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
  // Try each folding rule in turn; the first one that applies wins.
  if (OpFoldResult folded = FoldIdentityTranspose(*this)) return folded;
  if (OpFoldResult folded = FoldCancellableTranspose(*this)) return folded;
  return {};
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@@ -383,6 +383,28 @@ func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32
// CHECK: return %1 // CHECK: return %1
} }
// CHECK-LABEL: @cancellableTranspose
func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
// The two permutations are mutual inverses ([0,3,1,2] followed by [0,2,3,1]),
// so the pair of transposes cancels and the fold returns %arg0 directly.
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
return %3 : tensor<1x4x4x8xf32>
// CHECK: return %arg0
}
// CHECK-LABEL: @nonCancellableTranspose
func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> {
// The second permutation [2,0,3,1] is NOT the inverse of [0,3,1,2], so the
// transposes must not be folded away; the result of %3 is returned unchanged.
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
%1 = "tf.Const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<4x1x4x8xf32>
return %3 : tensor<4x1x4x8xf32>
// CHECK: return %3
}
// CHECK-LABEL: func @addN // CHECK-LABEL: func @addN
func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: return %arg0 // CHECK: return %arg0

View File

@@ -0,0 +1,24 @@
// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always
// CHECK-LABEL: func @transposeBiasAdd
func @transposeBiasAdd(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<8xf32>) -> tensor<1x8x4x4xf32> {
// End-to-end check for the layout optimization pipeline: forcing NCHW should
// rewrite the NHWC BiasAdd in place and cancel both surrounding transposes.

// Convert input: NCHW -> NHWC
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi64>) -> tensor<1x4x4x8xf32>
// Compute in NHWC
%2 = "tf.BiasAdd"(%1, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
// Convert result back: NHWC -> NCHW
%3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
%4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
// Check that BiasAdd computed in NCHW format, and all redundant transpose
// operations removed from the function.
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
// CHECK: return %[[BIAS_ADD]]
return %4 : tensor<1x8x4x4xf32>
}

View File

@@ -18,7 +18,9 @@ limitations under the License.
#include "mlir/IR/Builders.h" // TF:llvm-project #include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project #include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Pass/PassManager.h" // TF:llvm-project
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project #include "mlir/Pass/PassRegistry.h" // TF:llvm-project
#include "mlir/Transforms/Passes.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#define DEBUG_TYPE "tf-layout-optimization" #define DEBUG_TYPE "tf-layout-optimization"
@@ -28,11 +30,25 @@ namespace TF {
namespace { namespace {
// Layout optimization pipeline composes layout assignment and move transposes
// passes to pick the optimal layout for all layout sensitive operations, and
// cancel all redundant transposes.
//
// Options for the pipeline, parsed from the textual pass pipeline spec, e.g.
// `-tf-layout-optimization=force-data-format=NCHW`.
struct LayoutOptimizationPipelineOptions
: public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
// Data format forced on all layout sensitive ops.
// NOTE(review): behavior when left empty is decided by LayoutAssignmentPass
// (not visible here) — confirm before relying on a default.
Option<std::string> force_data_format{
*this, "force-data-format",
llvm::cl::desc("Force data format for all layout sensitive ops")};
};
// LayoutAssignmentPass assigns optimal data layout (data format) for all // LayoutAssignmentPass assigns optimal data layout (data format) for all
// layout sensitive operations. // layout sensitive operations.
class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> { class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
public: public:
LayoutAssignmentPass() = default; LayoutAssignmentPass() = default;
explicit LayoutAssignmentPass(const std::string& force_data_format) {
force_data_format_ = force_data_format;
}
LayoutAssignmentPass(const LayoutAssignmentPass& pass) {} LayoutAssignmentPass(const LayoutAssignmentPass& pass) {}
void runOnFunction() final; void runOnFunction() final;
@@ -52,6 +68,7 @@ class MoveTransposesPass : public FunctionPass<MoveTransposesPass> {
enum class Direction { kBegin, kEnd }; enum class Direction { kBegin, kEnd };
MoveTransposesPass() = default; MoveTransposesPass() = default;
explicit MoveTransposesPass(Direction direction) { direction_ = direction; }
MoveTransposesPass(const MoveTransposesPass& pass) {} MoveTransposesPass(const MoveTransposesPass& pass) {}
void runOnFunction() final; void runOnFunction() final;
@@ -356,6 +373,30 @@ void MoveTransposesPass::runOnFunction() {
MoveTransposeAfter(op, &work_list); MoveTransposeAfter(op, &work_list);
} }
} }
func.walk([&](TransposeOp transpose) {
OpBuilder builder(transpose);
SmallVector<Value, 1> fold_result;
if (succeeded(builder.tryFold(transpose.getOperation(), fold_result))) {
assert(fold_result.size() == 1);
transpose.replaceAllUsesWith(fold_result[0]);
}
});
}
// Builds the layout optimization pipeline: assign optimal layouts to all
// layout sensitive operations, then push transposes towards both ends of the
// block so that redundant transpose pairs fold away.
void CreateLayoutOptimizationPipeline(
    OpPassManager& pm,  // NOLINT - MLIR contract is pass by mutable reference.
    const LayoutOptimizationPipelineOptions& options) {
  // Assign optimal layout for layout sensitive ops.
  pm.addPass(std::make_unique<LayoutAssignmentPass>(options.force_data_format));

  // Move transposes to the beginning of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      MoveTransposesPass::Direction::kBegin));

  // Move transposes to the end of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      MoveTransposesPass::Direction::kEnd));
}
} // namespace } // namespace
@@ -365,5 +406,11 @@ static PassRegistration<LayoutAssignmentPass> layout_assignment(
static PassRegistration<MoveTransposesPass> move_transposes( static PassRegistration<MoveTransposesPass> move_transposes(
"tf-move-transposes", "Move transposes pass"); "tf-move-transposes", "Move transposes pass");
// Registers the `-tf-layout-optimization` pipeline with the global MLIR pass
// registry so it can be invoked from tools such as tf-opt.
static mlir::PassPipelineRegistration<LayoutOptimizationPipelineOptions>
pipeline("tf-layout-optimization",
"Assigns optimal data layout to all layout sensitive operations "
"and cancel redundant transpose operations.",
CreateLayoutOptimizationPipeline);
} // namespace TF } // namespace TF
} // namespace mlir } // namespace mlir