[TF:MLIR] Add canonicalization pattern to TransposeOp and compose a layout optimizer pipeline
PiperOrigin-RevId: 296081205 Change-Id: Ica9b311ba83e2e75b726eacbdc393c03692dacb8
This commit is contained in:
parent
2b95bfb6d8
commit
edaaeaddbd
@ -151,6 +151,26 @@ static bool AreCastCompatible(Type a, Type b) {
|
|||||||
b_kind == TensorFlowTypes::VARIANT;
|
b_kind == TensorFlowTypes::VARIANT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool AreCancellablePermutations(DenseIntElementsAttr perm0,
|
||||||
|
DenseIntElementsAttr perm1) {
|
||||||
|
if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false;
|
||||||
|
if (perm0.getNumElements() != perm1.getNumElements()) return false;
|
||||||
|
|
||||||
|
SmallVector<int64_t, 8> perm0_values;
|
||||||
|
for (auto value : perm0.getIntValues())
|
||||||
|
perm0_values.push_back(value.getSExtValue());
|
||||||
|
|
||||||
|
SmallVector<int64_t, 8> perm1_values;
|
||||||
|
for (auto value : perm1.getIntValues())
|
||||||
|
perm1_values.push_back(value.getSExtValue());
|
||||||
|
|
||||||
|
for (int i = 0; i < perm0_values.size(); ++i) {
|
||||||
|
if (perm0_values[perm1_values[i]] != i) return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
|
static bool IsUnknownDimOrRank(int64_t dim_or_rank) {
|
||||||
return dim_or_rank == -1;
|
return dim_or_rank == -1;
|
||||||
}
|
}
|
||||||
@ -2723,23 +2743,46 @@ void TransposeOp::build(Builder *builder, OperationState &result, Value x,
|
|||||||
perm);
|
perm);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
namespace {
|
||||||
auto const_perm = dyn_cast_or_null<TF::ConstOp>(perm().getDefiningOp());
|
|
||||||
|
|
||||||
if (!const_perm) {
|
OpFoldResult FoldIdentityTranspose(TransposeOp op) {
|
||||||
return {};
|
auto const_perm = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
|
||||||
}
|
if (!const_perm) return {};
|
||||||
|
|
||||||
auto const_value = const_perm.value();
|
auto const_value = const_perm.value();
|
||||||
|
|
||||||
const auto &elements = const_value.getValues<APInt>();
|
const auto &elements = const_value.getValues<APInt>();
|
||||||
|
|
||||||
for (auto it : llvm::enumerate(elements)) {
|
for (auto it : llvm::enumerate(elements)) {
|
||||||
if (it.index() != it.value()) {
|
if (it.index() != it.value()) return {};
|
||||||
return {};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return x();
|
return op.x();
|
||||||
|
}
|
||||||
|
|
||||||
|
OpFoldResult FoldCancellableTranspose(TransposeOp op) {
|
||||||
|
// Operand is a TransposeOp.
|
||||||
|
auto transpose = dyn_cast_or_null<TF::TransposeOp>(op.x().getDefiningOp());
|
||||||
|
if (!transpose) return {};
|
||||||
|
|
||||||
|
// Permutations defined by constant operations.
|
||||||
|
auto perm0 = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
|
||||||
|
auto perm1 = dyn_cast_or_null<TF::ConstOp>(transpose.perm().getDefiningOp());
|
||||||
|
if (!perm0 || !perm1) return {};
|
||||||
|
|
||||||
|
// With permutation indices that cancel each other
|
||||||
|
auto perm0_value = perm0.value().cast<DenseIntElementsAttr>();
|
||||||
|
auto perm1_value = perm1.value().cast<DenseIntElementsAttr>();
|
||||||
|
if (!AreCancellablePermutations(perm0_value, perm1_value)) return {};
|
||||||
|
|
||||||
|
return transpose.x();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
if (auto folded = FoldIdentityTranspose(*this)) return folded;
|
||||||
|
if (auto folded = FoldCancellableTranspose(*this)) return folded;
|
||||||
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -383,6 +383,28 @@ func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32
|
|||||||
// CHECK: return %1
|
// CHECK: return %1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @cancellableTranspose
|
||||||
|
func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
|
||||||
|
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
%1 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
|
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
|
||||||
|
|
||||||
|
return %3 : tensor<1x4x4x8xf32>
|
||||||
|
// CHECK: return %arg0
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @nonCancellableTranspose
|
||||||
|
func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> {
|
||||||
|
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
%1 = "tf.Const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||||
|
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||||
|
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<4x1x4x8xf32>
|
||||||
|
|
||||||
|
return %3 : tensor<4x1x4x8xf32>
|
||||||
|
// CHECK: return %3
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @addN
|
// CHECK-LABEL: func @addN
|
||||||
func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: return %arg0
|
// CHECK: return %arg0
|
||||||
|
@ -0,0 +1,24 @@
|
|||||||
|
// RUN: tf-opt %s -tf-layout-optimization=force-data-format=NCHW -verify-diagnostics | FileCheck %s --dump-input=always
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @transposeBiasAdd
|
||||||
|
func @transposeBiasAdd(%arg0: tensor<1x8x4x4xf32>, %arg1: tensor<8xf32>) -> tensor<1x8x4x4xf32> {
|
||||||
|
|
||||||
|
// Convert input: NCHW -> NHWC
|
||||||
|
%0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||||
|
%1 = "tf.Transpose"(%arg0, %0) : (tensor<1x8x4x4xf32>, tensor<4xi64>) -> tensor<1x4x4x8xf32>
|
||||||
|
|
||||||
|
// Compute in NHWC
|
||||||
|
%2 = "tf.BiasAdd"(%1, %arg1) {data_format = "NHWC"} : (tensor<1x4x4x8xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
|
||||||
|
|
||||||
|
// Convert result back: NHWC -> NCHW
|
||||||
|
%3 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi64>} : () -> tensor<4xi64>
|
||||||
|
%4 = "tf.Transpose"(%2, %3) : (tensor<1x4x4x8xf32>, tensor<4xi64>) -> tensor<1x8x4x4xf32>
|
||||||
|
|
||||||
|
// Check that BiasAdd computed in NCHW format, and all redundant transpose
|
||||||
|
// operations removed from the function.
|
||||||
|
|
||||||
|
// CHECK: %[[BIAS_ADD:[0-9]*]] = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} {{.*}} tensor<1x8x4x4xf32>
|
||||||
|
// CHECK: return %[[BIAS_ADD]]
|
||||||
|
|
||||||
|
return %4 : tensor<1x8x4x4xf32>
|
||||||
|
}
|
@ -18,7 +18,9 @@ limitations under the License.
|
|||||||
#include "mlir/IR/Builders.h" // TF:llvm-project
|
#include "mlir/IR/Builders.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Function.h" // TF:llvm-project
|
#include "mlir/IR/Function.h" // TF:llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
#include "mlir/Pass/Pass.h" // TF:llvm-project
|
||||||
|
#include "mlir/Pass/PassManager.h" // TF:llvm-project
|
||||||
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
|
#include "mlir/Pass/PassRegistry.h" // TF:llvm-project
|
||||||
|
#include "mlir/Transforms/Passes.h" // TF:llvm-project
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "tf-layout-optimization"
|
#define DEBUG_TYPE "tf-layout-optimization"
|
||||||
@ -28,11 +30,25 @@ namespace TF {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Layout optimization pipeline composes layout assignment and move transposes
|
||||||
|
// passes to pick the optimal layout for all layout sensitive operations, and
|
||||||
|
// cancel all redundant transposes.
|
||||||
|
struct LayoutOptimizationPipelineOptions
|
||||||
|
: public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
|
||||||
|
Option<std::string> force_data_format{
|
||||||
|
*this, "force-data-format",
|
||||||
|
llvm::cl::desc("Force data format for all layout sensitive ops")};
|
||||||
|
};
|
||||||
|
|
||||||
// LayoutAssignmentPass assigns optimal data layout (data format) for all
|
// LayoutAssignmentPass assigns optimal data layout (data format) for all
|
||||||
// layout sensitive operations.
|
// layout sensitive operations.
|
||||||
class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
|
class LayoutAssignmentPass : public FunctionPass<LayoutAssignmentPass> {
|
||||||
public:
|
public:
|
||||||
LayoutAssignmentPass() = default;
|
LayoutAssignmentPass() = default;
|
||||||
|
explicit LayoutAssignmentPass(const std::string& force_data_format) {
|
||||||
|
force_data_format_ = force_data_format;
|
||||||
|
}
|
||||||
|
|
||||||
LayoutAssignmentPass(const LayoutAssignmentPass& pass) {}
|
LayoutAssignmentPass(const LayoutAssignmentPass& pass) {}
|
||||||
|
|
||||||
void runOnFunction() final;
|
void runOnFunction() final;
|
||||||
@ -52,6 +68,7 @@ class MoveTransposesPass : public FunctionPass<MoveTransposesPass> {
|
|||||||
enum class Direction { kBegin, kEnd };
|
enum class Direction { kBegin, kEnd };
|
||||||
|
|
||||||
MoveTransposesPass() = default;
|
MoveTransposesPass() = default;
|
||||||
|
explicit MoveTransposesPass(Direction direction) { direction_ = direction; }
|
||||||
MoveTransposesPass(const MoveTransposesPass& pass) {}
|
MoveTransposesPass(const MoveTransposesPass& pass) {}
|
||||||
|
|
||||||
void runOnFunction() final;
|
void runOnFunction() final;
|
||||||
@ -356,6 +373,30 @@ void MoveTransposesPass::runOnFunction() {
|
|||||||
MoveTransposeAfter(op, &work_list);
|
MoveTransposeAfter(op, &work_list);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func.walk([&](TransposeOp transpose) {
|
||||||
|
OpBuilder builder(transpose);
|
||||||
|
SmallVector<Value, 1> fold_result;
|
||||||
|
if (succeeded(builder.tryFold(transpose.getOperation(), fold_result))) {
|
||||||
|
assert(fold_result.size() == 1);
|
||||||
|
transpose.replaceAllUsesWith(fold_result[0]);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void CreateLayoutOptimizationPipeline(
|
||||||
|
OpPassManager& pm, // NOLINT - MLIR contract is pass by mutable reference.
|
||||||
|
const LayoutOptimizationPipelineOptions& options) {
|
||||||
|
using Direction = MoveTransposesPass::Direction;
|
||||||
|
|
||||||
|
// Assign optimal layout for layout sensitive ops.
|
||||||
|
pm.addPass(std::make_unique<LayoutAssignmentPass>(options.force_data_format));
|
||||||
|
|
||||||
|
// Move transposes to the beginning of the block and try to fold them.
|
||||||
|
pm.addPass(std::make_unique<MoveTransposesPass>(Direction::kBegin));
|
||||||
|
|
||||||
|
// Move transposes to the end of the block and try to fold them.
|
||||||
|
pm.addPass(std::make_unique<MoveTransposesPass>(Direction::kEnd));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -365,5 +406,11 @@ static PassRegistration<LayoutAssignmentPass> layout_assignment(
|
|||||||
static PassRegistration<MoveTransposesPass> move_transposes(
|
static PassRegistration<MoveTransposesPass> move_transposes(
|
||||||
"tf-move-transposes", "Move transposes pass");
|
"tf-move-transposes", "Move transposes pass");
|
||||||
|
|
||||||
|
static mlir::PassPipelineRegistration<LayoutOptimizationPipelineOptions>
|
||||||
|
pipeline("tf-layout-optimization",
|
||||||
|
"Assigns optimal data layout to all layout sensitive operations "
|
||||||
|
"and cancel redundant transpose operations.",
|
||||||
|
CreateLayoutOptimizationPipeline);
|
||||||
|
|
||||||
} // namespace TF
|
} // namespace TF
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
Loading…
x
Reference in New Issue
Block a user