From edaaeaddbdf996a089b3041c0d8fe4677e37c9e0 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev <ezhulenev@google.com>
Date: Wed, 19 Feb 2020 16:12:20 -0800
Subject: [PATCH] [TF:MLIR] Add canonicalization pattern to TransposeOp and
 compose a layout optimizer pipeline

PiperOrigin-RevId: 296081205
Change-Id: Ica9b311ba83e2e75b726eacbdc393c03692dacb8
---
 .../compiler/mlir/tensorflow/ir/tf_ops.cc     | 63 ++++++++++++++++---
 .../mlir/tensorflow/tests/canonicalize.mlir   | 22 +++++++
 .../tensorflow/tests/layout_optimization.mlir | 24 +++++++
 .../transforms/layout_optimization.cc         | 47 ++++++++++++++
 4 files changed, 146 insertions(+), 10 deletions(-)
 create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/layout_optimization.mlir

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 0d70d8793ee..c97f2ed5420 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -151,6 +151,26 @@ static bool AreCastCompatible(Type a, Type b) {
          b_kind == TensorFlowTypes::VARIANT;
 }
 
+static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
+                                       DenseIntElementsAttr perm1) {
+  if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
+  if (perm0.getNumElements() != perm1.getNumElements()) return false;
+
+  SmallVector<int64_t, 8> perm0_values;
+  for (auto value : perm0.getIntValues())
+    perm0_values.push_back(value.getSExtValue());
+
+  SmallVector<int64_t, 8> perm1_values;
+  for (auto value : perm1.getIntValues())
+    perm1_values.push_back(value.getSExtValue());
+
+  for (int i = 0; i < perm0_values.size(); ++i) {
+    if (perm0_values[perm1_values[i]] != i) return false;
+  }
+
+  return true;
+}
+
 static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
   return dim_or_rank == -1;
 }
@@ -2723,23 +2743,46 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x,
         perm);
 }
 
-OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
-  auto const_perm =
-      dyn_cast_or_null<TF::ConstOp>(perm().getDefiningOp());
+namespace {
 
-  if (!const_perm) {
-    return {};
-  }
+OpFoldResult FoldIdentityTranspose(TransposeOp op) {
+  auto const_perm = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
+  if (!const_perm) return {};
 
   auto const_value = const_perm.value();
   const auto &elements = const_value.getValues<APInt>();
 
   for (auto it : llvm::enumerate(elements)) {
-    if (it.index() != it.value()) {
-      return {};
-    }
+    if (it.index() != it.value()) return {};
   }
 
-  return x();
+  return op.x();
+}
+
+OpFoldResult FoldCancellableTranspose(TransposeOp op) {
+  // Operand is a TransposeOp.
+  auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
+  if (!transpose) return {};
+
+  // Permutations defined by constant operations.
+  auto perm0 = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
+  auto perm1 = dyn_cast_or_null<TF::ConstOp>(transpose.perm().getDefiningOp());
+  if (!perm0 || !perm1) return {};
+
+  // With permutation indices that cancel each other
+  auto perm0_value = perm0.value().cast<DenseIntElementsAttr>();
+  auto perm1_value = perm1.value().cast<DenseIntElementsAttr>();
+  if (!AreCancellablePermutations(perm0_value, perm1_value)) return {};
+
+  return transpose.x();
+}
+
+}  // namespace
+
+OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
+  if (auto folded = FoldIdentityTranspose(*this)) return folded;
+  if (auto folded = FoldCancellableTranspose(*this)) return folded;
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index c91c1e2f7b5..5bf5b0610ae 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -383,6 +383,28 @@ func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32
   // CHECK: return %1
 }
 
+// CHECK-LABEL: @cancellableTranspose
+func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
+  %0
= "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
+  %3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
+
+  return %3 : tensor<1x4x4x8xf32>
+  // CHECK: return %arg0
+}
+
+// CHECK-LABEL: @nonCancellableTranspose
+func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> {
+  %0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %1 = "tf.Const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
+  %3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<4x1x4x8xf32>
+
+  return %3 : tensor<4x1x4x8xf32>
+  // CHECK: return %3
+}
+
 // CHECK-LABEL: func @addN
 func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: return %arg0
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization.mlir
new file mode 100644
index 00000000000..44330d675e2
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization.mlir
@@ -0,0 +1,24 @@
+// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always
+
+// CHECK-LABEL: func @transposeBiasAdd
+func @transposeBiasAdd(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<8xf32>) -> tensor<1x8x4x4xf32> {
+
+  // Convert input: NCHW -> NHWC
+  %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
+  %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi64>) -> tensor<1x4x4x8xf32>
+
+  // Compute in NHWC
+  %2 = "tf.BiasAdd"(%1, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+
+  // Convert result back: NHWC -> NCHW
+  %3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
+  %4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
+
+  // Check that BiasAdd computed in NCHW format, and all redundant transpose
+  // operations removed from the function.
+
+  // CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
+  // CHECK: return %[[BIAS_ADD]]
+
+  return %4 : tensor<1x8x4x4xf32>
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc
index ba46059e5b6..feef3516ade 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc
@@ -18,7 +18,9 @@ limitations under the License.
 #include "mlir/IR/Builders.h"  // TF:llvm-project
 #include "mlir/IR/Function.h"  // TF:llvm-project
 #include "mlir/Pass/Pass.h"  // TF:llvm-project
+#include "mlir/Pass/PassManager.h"  // TF:llvm-project
 #include "mlir/Pass/PassRegistry.h"  // TF:llvm-project
+#include "mlir/Transforms/Passes.h"  // TF:llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 
 #define DEBUG_TYPE "tf-layout-optimization"
@@ -28,11 +30,25 @@ namespace TF {
 
 namespace {
 
+// Layout optimization pipeline composes layout assignment and move transposes
+// passes to pick the optimal layout for all layout sensitive operations, and
+// cancel all redundant transposes.
+struct LayoutOptimizationPipelineOptions
+    : public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
+  Option<std::string> force_data_format{
+      *this, "force-data-format",
+      llvm::cl::desc("Force data format for all layout sensitive ops")};
+};
+
 // LayoutAssignmentPass assigns optimal data layout (data format) for all
 // layout sensitive operations.
class LayoutAssignmentPass : public FunctionPass { public: LayoutAssignmentPass() = default; + explicit LayoutAssignmentPass(const std::string& force_data_format) { + force_data_format_ = force_data_format; + } + LayoutAssignmentPass(const LayoutAssignmentPass& pass) {} void runOnFunction() final; @@ -52,6 +68,7 @@ class MoveTransposesPass : public FunctionPass { enum class Direction { kBegin, kEnd }; MoveTransposesPass() = default; + explicit MoveTransposesPass(Direction direction) { direction_ = direction; } MoveTransposesPass(const MoveTransposesPass& pass) {} void runOnFunction() final; @@ -356,6 +373,30 @@ void MoveTransposesPass::runOnFunction() { MoveTransposeAfter(op, &work_list); } } + + func.walk([&](TransposeOp transpose) { + OpBuilder builder(transpose); + SmallVector fold_result; + if (succeeded(builder.tryFold(transpose.getOperation(), fold_result))) { + assert(fold_result.size() == 1); + transpose.replaceAllUsesWith(fold_result[0]); + } + }); +} + +void CreateLayoutOptimizationPipeline( + OpPassManager& pm, // NOLINT - MLIR contract is pass by mutable reference. + const LayoutOptimizationPipelineOptions& options) { + using Direction = MoveTransposesPass::Direction; + + // Assign optimal layout for layout sensitive ops. + pm.addPass(std::make_unique(options.force_data_format)); + + // Move transposes to the beginning of the block and try to fold them. + pm.addPass(std::make_unique(Direction::kBegin)); + + // Move transposes to the end of the block and try to fold them. + pm.addPass(std::make_unique(Direction::kEnd)); } } // namespace @@ -365,5 +406,11 @@ static PassRegistration layout_assignment( static PassRegistration move_transposes( "tf-move-transposes", "Move transposes pass"); +static mlir::PassPipelineRegistration + pipeline("tf-layout-optimization", + "Assigns optimal data layout to all layout sensitive operations " + "and cancel redundant transpose operations.", + CreateLayoutOptimizationPipeline); + } // namespace TF } // namespace mlir