[XLA:CPU] Teach dot_op_emitter how to tile&vectorize linalg matmuls

And turn them on by default.

This is on par with the existing emitter, sometimes better, and it unlocks more
potential. The strategy classes are duplicated right now, but I expect them to
graduate to MLIR core soon.

I'm planning to remove the custom LLVM IR emitters if this turns out to be
stable enough.

PiperOrigin-RevId: 320117625
Change-Id: I3580df9990ca2a022a49327fa819c2086fd1e2ed
A. Unique TensorFlower authored on 2020-07-07 20:59:03 -07:00; committed by TensorFlower Gardener
parent 82e12bf387
commit 3dda4182aa
8 changed files with 14 additions and 541 deletions

View File

@@ -471,7 +471,6 @@ cc_library(
":cpu_runtime",
":ir_emission_utils",
":mlir_emitter",
":mlir_matmul_codegen_strategy",
":target_machine_features",
":tiled_dot_emitter",
":vector_support_library",
@@ -1103,33 +1102,12 @@ cc_library(
"@llvm-project//llvm:Core",
"@llvm-project//llvm:IPO",
"@llvm-project//llvm:Linker",
"@llvm-project//mlir:CFGTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMTransforms",
"@llvm-project//mlir:LinalgToLLVM",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TargetLLVMIR",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorToLLVM",
],
)
cc_library(
name = "mlir_matmul_codegen_strategy",
srcs = ["mlir_matmul_codegen_strategy.cc"],
hdrs = ["mlir_matmul_codegen_strategy.h"],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:VectorOps",
"@llvm-project//mlir:VectorToSCF",
],
)

View File

@@ -25,6 +25,7 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
const char* const kXlaForceEnableExperimentalLlvmIrGemm =
"xla_force_enable_experimental_llvm_ir_gemm";
const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot";
const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size";
} // namespace
@@ -63,6 +64,12 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0;
}
bool UseLinalgForDot(const HloModuleConfig& config) {
const auto& extra_options_map =
config.debug_options().xla_backend_extra_options();
return extra_options_map.count(kXlaUseLinalgForDot) > 0;
}
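For context, a minimal sketch (not part of this change) of how the new backend option could be set from client code, assuming the usual DebugOptions-to-HloModuleConfig plumbing; per the lookup above, only the key's presence matters, not its value:

  // Hedged sketch: enable the linalg dot path via the backend extra option
  // consulted by options::UseLinalgForDot().
  xla::DebugOptions debug_options;
  (*debug_options.mutable_xla_backend_extra_options())["xla_use_linalg_for_dot"] =
      "";  // Only the key's presence is checked.
  xla::HloModuleConfig config;
  config.set_debug_options(debug_options);
  // options::UseLinalgForDot(config) now returns true.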
static absl::string_view RemoveSuffix(absl::string_view str,
absl::string_view suffix) {
CHECK_GE(str.size(), suffix.size());

View File

@@ -31,12 +31,10 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
@@ -204,20 +202,6 @@ class DotOpEmitter {
.value_or(kDefaultTileSize);
}
std::array<int64_t, 3> GetMlirGemmTileSize() const {
// Tile by 4 x registers x register size. This was picked by running
// small matmuls on Haswell and Skylake. There's a lot of room for
// improvement here.
constexpr int64_t kDefaultTileSizeForM = 4;
int64_t elements_per_register =
target_machine_features_.vector_register_num_elements(
*b_->GetInsertBlock()->getParent(),
dot_info_.result_shape.element_type());
int64_t num_registers = target_machine_features_.vector_register_count(
*b_->GetInsertBlock()->getParent());
return {{kDefaultTileSizeForM, num_registers, elements_per_register}};
}
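As a rough worked example of the tile this yields (the register figures below are assumptions for a hypothetical AVX2-class core, not measurements from this change):

  // Hypothetical AVX2-like target: 16 vector registers, 256 bits each.
  int64_t elements_per_register = 256 / 32;  // 8 f32 lanes per register
  int64_t num_registers = 16;
  std::array<int64_t, 3> tile = {4, num_registers, elements_per_register};
  // GetMlirGemmTileSize() above would then return {4, 16, 8}.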
DotInfo dot_info_;
string dot_hlo_name_;
const llvm_ir::IrArray& target_array_;
@@ -266,7 +250,6 @@ Status DotOpEmitter::EmitLinalgMatmul() {
absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
dot_info_.lhs_shape.ToString(true), "_",
dot_info_.rhs_shape.ToString(true));
return EmitMlirFuncAndCall(
mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
@@ -276,27 +259,6 @@ Status DotOpEmitter::EmitLinalgMatmul() {
mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{},
mlir::ValueRange{b, c, a});
mlir::edsc::intrinsics::std_ret();
mlir::linalg::LinalgTilingOptions tilingOptions;
tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize());
int64 alignment =
target_machine_features_.minimum_alignment_for_allocation(
ShapeUtil::ByteSizeOf(dot_info_.result_shape));
mlir_strategy::MatmulCodegenStrategy strategy;
strategy.tile<mlir::linalg::MatmulOp>(tilingOptions)
.promote<mlir::linalg::MatmulOp>(
mlir::linalg::LinalgPromotionOptions()
.setAlignment(alignment)
.setUseFullTileBuffersByDefault(true)
.setUseAlloca(true))
.vectorize<mlir::linalg::MatmulOp>()
.setVectorTransformsOptions(
mlir::vector::VectorTransformsOptions()
.setVectorTransformsOptions(
mlir::vector::VectorContractLowering::OuterProduct))
.setVectorTransferToSCFOptions(
mlir::VectorTransferToSCFOptions().setUnroll(true));
strategy.transform(function);
});
}
@@ -986,8 +948,7 @@ DotImplementationStrategy GetDotImplementationStrategy(
if (IsAlignedGemm(dot_info, target_machine_features)) {
if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
return primitive_util::IsFloatingPointType(
dot_info.result_shape.element_type())
return options::UseLinalgForDot(config)
? DotImplementationStrategy::kLinalgMatmul
: DotImplementationStrategy::kTiledLlvmIrGemm;
}

View File

@@ -17,14 +17,14 @@ limitations under the License.
#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Target/LLVMIR.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
namespace xla {
@@ -35,9 +35,9 @@ namespace {
std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module) {
mlir::PassManager manager(module->getContext());
manager.addPass(mlir::createConvertLinalgToLoopsPass());
manager.addPass(mlir::createLowerAffinePass());
manager.addPass(mlir::createLowerToCFGPass());
manager.addPass(mlir::createConvertLinalgToLLVMPass());
manager.addPass(mlir::createConvertVectorToLLVMPass());
manager.addPass(mlir::createLowerToLLVMPass());
CHECK(succeeded(manager.run(*module)));
return mlir::translateModuleToLLVMIR(*module);
}

View File

@@ -1,269 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h" // from @llvm-project
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Utils/Utils.h" // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/SCF/Utils.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" // from @llvm-project
#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project
#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Dominance.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
// TODO(kramerb): Remove this once strategy is in mlir core.
using namespace mlir; // NOLINT
using namespace mlir::linalg; // NOLINT
#define DEBUG_TYPE "matmul-codegen-strategy"
namespace xla {
namespace cpu {
namespace mlir_strategy {
//===----------------------------------------------------------------------===//
// TODO: Cleanup and upstream these to go into core. Please ignore for now !
//===----------------------------------------------------------------------===//
static void hoistRedundantCopies(FuncOp func) {
bool changed = true;
while (changed) {
changed = false;
func.walk([&](linalg::FillOp op) {
auto loop = op.getParentOfType<scf::ForOp>();
if (!loop) return;
for (auto operand : op.getOperands())
if (!loop.isDefinedOutsideOfLoop(operand)) return;
// Hoist fill before.
op.getOperation()->moveBefore(loop);
changed = true;
});
func.walk([&](linalg::CopyOp op) {
auto loop = op.getParentOfType<scf::ForOp>();
if (!loop) return;
for (auto operand : op.getOperands())
if (!loop.isDefinedOutsideOfLoop(operand)) return;
Value sourceView = op.getInput(0);
while (auto subViewOp = sourceView.getDefiningOp<SubViewOp>())
sourceView = subViewOp.getViewSource();
// Source traces back to a block argument.
if (sourceView.isa<BlockArgument>()) {
op.getOperation()->moveBefore(loop);
} else {
assert(sourceView.getDefiningOp<ViewOp>() ||
sourceView.getDefiningOp<AllocOp>() ||
sourceView.getDefiningOp<AllocaOp>());
op.getOperation()->moveAfter(loop);
}
changed = true;
});
}
}
/// Substitute scf.for = %lb to %ub step %step by an AffineExpr expressing:
/// `%lb + %step * new_dim` where
/// 1. the AffineExpr for %lb is either an AffineConstantExpr or an
/// AffineDimExpr depending on whether the value is constant or not.
/// 2. the AffineExpr for %step is either an AffineConstantExpr or an
/// AffineSymbolExpr depending on whether the value is constant or not.
///
static void substitute(scf::ForOp forOp, SmallVectorImpl<AffineExpr> &exprs,
SmallVectorImpl<Value> &dims,
SmallVectorImpl<Value> &symbols) {
MLIRContext *ctx = forOp.getContext();
auto lbConstant = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx)
: getAffineDimExpr(dims.size(), ctx);
auto stepConstant = forOp.step().getDefiningOp<ConstantIndexOp>();
AffineExpr step = stepConstant
? getAffineConstantExpr(stepConstant.getValue(), ctx)
: getAffineSymbolExpr(symbols.size(), ctx);
if (!lbConstant) dims.push_back(forOp.lowerBound());
if (!stepConstant) symbols.push_back(forOp.step());
exprs.push_back(lb + step * getAffineDimExpr(dims.size(), ctx));
auto ubConstant = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
AffineExpr ub = ubConstant ? getAffineConstantExpr(ubConstant.getValue(), ctx)
: getAffineDimExpr(dims.size(), ctx);
if (!ubConstant) dims.push_back(forOp.upperBound());
exprs.push_back(ub);
dims.push_back(forOp.getInductionVar());
}
/// Traverse the dim operands of `minOp`, recursively substituting loop
/// induction variables and nested affine.min results.
static void substitute(AffineMinOp minOp, SmallVectorImpl<AffineExpr> &exprs,
SmallVectorImpl<Value> &dims,
SmallVectorImpl<Value> &symbols) {
MLIRContext *ctx = minOp.getContext();
for (Value v : minOp.getDimOperands()) {
if (auto forOp = scf::getForInductionVarOwner(v)) {
substitute(forOp, exprs, dims, symbols);
continue;
}
if (auto parentMinOp = v.getDefiningOp<AffineMinOp>()) {
substitute(parentMinOp, exprs, dims, symbols);
continue;
}
exprs.push_back(getAffineDimExpr(dims.size(), ctx));
dims.push_back(v);
}
}
/// Perform folding of chains of AffineMinOp.
struct AffineMinCanonicalizationPattern : public OpRewritePattern<AffineMinOp> {
using OpRewritePattern<AffineMinOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineMinOp minOp,
PatternRewriter &rewriter) const override;
};
LogicalResult AffineMinCanonicalizationPattern::matchAndRewrite(
AffineMinOp minOp, PatternRewriter &rewriter) const {
LLVM_DEBUG(llvm::dbgs() << "\nCanonicalize AffineMin: "
<< *minOp.getOperation() << "\n");
int64_t min = std::numeric_limits<int64_t>::max();
for (auto e : minOp.map().getResults())
if (auto cstExpr = e.dyn_cast<AffineConstantExpr>())
min = std::min(min, cstExpr.getValue());
if (min == std::numeric_limits<int64_t>::max()) return failure();
SmallVector<AffineExpr, 4> exprs;
SmallVector<Value, 4> dims, symbols;
substitute(minOp, exprs, dims, symbols);
SmallVector<Value, 4> operands = dims;
operands.append(symbols.begin(), symbols.end());
MLIRContext *ctx = minOp.getContext();
auto map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx);
LLVM_DEBUG(llvm::dbgs() << "Substitution map: " << map << "\n");
SmallVector<AffineExpr, 4> modExprs;
for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx)
modExprs.push_back(getAffineDimExpr(idx, ctx) % min);
map = AffineMap::get(map.getNumResults(), 0, modExprs, ctx).compose(map);
canonicalizeMapAndOperands(&map, &operands);
map = simplifyAffineMap(map);
LLVM_DEBUG(llvm::dbgs() << "Post mod: " << map << "\n";
llvm::interleaveComma(operands, llvm::dbgs()));
if (!llvm::all_of(map.getResults(), [](AffineExpr e) {
if (auto cst = e.dyn_cast<AffineConstantExpr>())
return cst.getValue() == 0;
return false;
}))
return failure();
rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, min);
return success();
}
//===----------------------------------------------------------------------===//
// END TODO
//===----------------------------------------------------------------------===//
void MatmulCodegenStrategy::transform(FuncOp func) const {
MLIRContext *context = func.getContext();
// Emplace patterns one at a time while also maintaining a simple chained
// state transition.
unsigned stepCount = 0;
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
auto zeroState = Identifier::get(std::to_string(stepCount), context);
auto currentState = zeroState;
for (auto &t : transformation_sequence) {
auto nextState = Identifier::get(std::to_string(++stepCount), context);
auto marker = (currentState == zeroState)
? linalg::LinalgMarker({}, nextState)
: linalg::LinalgMarker(currentState, nextState);
stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
currentState = nextState;
}
OwningRewritePatternList stage2Patterns =
linalg::getLinalgTilingCanonicalizationPatterns(context);
stage2Patterns.insert<AffineMinCanonicalizationPattern>(context);
auto stage3Transforms = [](Operation *op) {
// Some of these may be too aggressive as stage-3 transforms, since they run
// after each stage-1 application; they may have to be split out into a
// post-staged-patterns application (in which case they could just be passes, TBD).
PassManager pm(op->getContext());
pm.addPass(createLoopInvariantCodeMotionPass());
if (failed(pm.run(op->getParentOfType<ModuleOp>())))
llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
promoteSingleIterationLoops(cast<FuncOp>(op));
hoistViewAllocOps(cast<FuncOp>(op));
hoistRedundantVectorTransfers(cast<FuncOp>(op));
hoistRedundantCopies(cast<FuncOp>(op));
return success();
};
linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
stage3Transforms);
//===--------------------------------------------------------------------===//
// Post staged patterns transforms
//===--------------------------------------------------------------------===//
// Programmatic controlled lowering of vector.contract only.
OwningRewritePatternList vectorContractLoweringPatterns;
vectorContractLoweringPatterns
.insert<ContractionOpToOuterProductOpLowering,
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
vector_transforms_options, context);
applyPatternsAndFoldGreedily(func, vectorContractLoweringPatterns);
// Programmatic controlled lowering of vector.transfer only.
OwningRewritePatternList vectorToLoopsPatterns;
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
vector_to_scf_options);
applyPatternsAndFoldGreedily(func, vectorToLoopsPatterns);
}
} // namespace mlir_strategy
} // namespace cpu
} // namespace xla

View File

@@ -1,188 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_
#define MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project
#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
// TODO(kramerb): Remove this once strategy is in mlir core.
namespace xla {
namespace cpu {
namespace mlir_strategy {
/// Abstract Transformation class applied in a sequence that also handles state
/// through markers.
struct Transformation {
virtual ~Transformation() = default;
virtual mlir::OwningRewritePatternList buildRewritePatterns(
mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) = 0;
mlir::linalg::LinalgMarker marker;
};
/// Tiling transformation enqueues a particular stage-1 pattern for
/// `Tile<LinalgOpType>` with the appropriate `options`.
// TODO: variadic LinalgOpTypes.
template <typename LinalgOpType>
struct Tile : public Transformation {
explicit Tile(mlir::linalg::LinalgTilingOptions options) : options(options) {}
mlir::OwningRewritePatternList buildRewritePatterns(
mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
mlir::OwningRewritePatternList tiling_patterns;
tiling_patterns.insert<mlir::linalg::LinalgTilingPattern<LinalgOpType>>(
context, options, m);
return tiling_patterns;
}
private:
mlir::linalg::LinalgTilingOptions options;
};
/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Promote<LinalgOpType>` with the appropriate `options`.
// TODO: variadic LinalgOpTypes.
template <typename LinalgOpType>
struct Promote : public Transformation {
explicit Promote(mlir::linalg::LinalgPromotionOptions options)
: options(options) {}
mlir::OwningRewritePatternList buildRewritePatterns(
mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
mlir::OwningRewritePatternList promotion_patterns;
promotion_patterns
.insert<mlir::linalg::LinalgPromotionPattern<LinalgOpType>>(context,
options, m);
return promotion_patterns;
}
private:
mlir::linalg::LinalgPromotionOptions options;
};
/// Vectorization transformation enqueues a particular stage-1 pattern for
/// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
/// transfer rewrite forwarding patterns.
// TODO: variadic LinalgOpTypes.
template <typename LinalgOpType>
struct Vectorize : public Transformation {
mlir::OwningRewritePatternList buildRewritePatterns(
mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
mlir::OwningRewritePatternList vectorization_patterns;
// FillOp may interfere with forwarding patterns atm, so we bump up the
// priority of LinalgCopyVTRForwardingPattern /
// LinalgCopyVTWForwardingPattern.
vectorization_patterns
.insert<mlir::linalg::LinalgVectorizationPattern<LinalgOpType>>(context,
m);
vectorization_patterns.insert<mlir::linalg::LinalgCopyVTRForwardingPattern,
mlir::linalg::LinalgCopyVTWForwardingPattern>(
context,
/*benefit=*/2);
return vectorization_patterns;
}
};
/// Matmul-specific strategy object controls how a linalg.matmul is
/// progressively lowered.
/// The strategy uses a 3-level staged patterns strategy which allows ordering
/// transformations by using the Linalg `applyStagedPatterns` function, where:
/// 1. The first stage consists of the successive `tile`, `promote` and
/// `vectorize` patterns, applied sequentially.
/// 2. The second stage consists of common local canonicalization patterns
/// that are applied eagerly after each stage-1 pattern.
/// 3. The third stage consists of more global transformations, also applied
/// eagerly after all stage-2 patterns.
struct MatmulCodegenStrategy {
/// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
/// `options`.
template <typename LinalgOpType>
MatmulCodegenStrategy &tile(mlir::linalg::LinalgTilingOptions options) {
transformation_sequence.emplace_back(new Tile<LinalgOpType>(options));
return *this;
}
/// Conditionally append a pattern to add a level of tiling for `LinalgOpType`
/// with tiling `options`.
template <typename LinalgOpType>
MatmulCodegenStrategy &tileIf(bool b,
mlir::linalg::LinalgTilingOptions options) {
return b ? tile<LinalgOpType>(options) : *this;
}
/// Append a pattern to add a level of promotion for `LinalgOpType` with
/// promotion `options`.
template <typename LinalgOpType>
MatmulCodegenStrategy &promote(mlir::linalg::LinalgPromotionOptions options) {
transformation_sequence.emplace_back(new Promote<LinalgOpType>(options));
return *this;
}
/// Conditionally append a pattern to add a level of promotion for
/// `LinalgOpType` with promotion `options`.
template <typename LinalgOpType>
MatmulCodegenStrategy &promoteIf(
bool b, mlir::linalg::LinalgPromotionOptions options) {
return b ? promote<LinalgOpType>(options) : *this;
}
/// Append a pattern to rewrite `LinalgOpType` as a vector operation.
template <typename LinalgOpType>
MatmulCodegenStrategy &vectorize() {
transformation_sequence.emplace_back(new Vectorize<LinalgOpType>());
return *this;
}
/// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
/// operation.
template <typename LinalgOpType>
MatmulCodegenStrategy &vectorizeIf(bool b) {
return b ? vectorize<LinalgOpType>() : *this;
}
/// Configure the post staged-patterns late vector transformations.
MatmulCodegenStrategy &setVectorTransformsOptions(
mlir::vector::VectorTransformsOptions options) {
vector_transforms_options = options;
return *this;
}
/// Configure the post staged-patterns late vector.transfer to scf conversion.
MatmulCodegenStrategy &setVectorTransferToSCFOptions(
mlir::VectorTransferToSCFOptions options) {
vector_to_scf_options = options;
return *this;
}
/// Apply the transformation patterns in sequence with cleanup transformations
/// interleaved.
void transform(mlir::FuncOp func) const;
private:
mlir::LogicalResult postPatternTransforms(mlir::Operation *func) const;
mlir::vector::VectorTransformsOptions vector_transforms_options;
mlir::VectorTransferToSCFOptions vector_to_scf_options;
llvm::SmallVector<std::unique_ptr<Transformation>, 4> transformation_sequence;
};
} // namespace mlir_strategy
} // namespace cpu
} // namespace xla
#endif // MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_
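For reference, a minimal usage sketch of this builder interface, mirroring the call site in dot_op_emitter.cc shown earlier; the tile sizes and alignment are placeholders, and `func` is assumed to be an mlir::FuncOp containing a linalg.matmul:

  xla::cpu::mlir_strategy::MatmulCodegenStrategy strategy;
  strategy
      .tile<mlir::linalg::MatmulOp>(
          mlir::linalg::LinalgTilingOptions().setTileSizes({4, 16, 8}))
      .promote<mlir::linalg::MatmulOp>(
          mlir::linalg::LinalgPromotionOptions()
              .setAlignment(16)  // placeholder alignment
              .setUseFullTileBuffersByDefault(true)
              .setUseAlloca(true))
      .vectorize<mlir::linalg::MatmulOp>()
      .setVectorTransformsOptions(
          mlir::vector::VectorTransformsOptions().setVectorTransformsOptions(
              mlir::vector::VectorContractLowering::OuterProduct))
      .setVectorTransferToSCFOptions(
          mlir::VectorTransferToSCFOptions().setUnroll(true));
  strategy.transform(func);  // Applies tiling, promotion, and vectorization.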

View File

@@ -52,12 +52,6 @@ class TargetMachineFeatures {
virtual int vector_register_num_elements(const llvm::Function& function,
PrimitiveType type) const = 0;
// Return the number of vector registers. We need to pass in
// "function" since llvm functions can contain annotations for specializing
// them to specific micro-architectures (though currently XLA does not use
// this functionality).
virtual int vector_register_count(const llvm::Function& function) const = 0;
// Returns the minimum alignment for a buffer of size size_bytes.
virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0;
@@ -90,12 +84,6 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures {
(primitive_util::BitWidth(type) / 8);
}
int vector_register_count(const llvm::Function& function) const override {
llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function);
return static_cast<int>(tti->getNumberOfRegisters(
tti->getRegisterClassForType(/*Vector=*/true)));
}
int64 minimum_alignment_for_allocation(int64 size_bytes) const override;
private:

View File

@@ -44,10 +44,6 @@ class TargetMachineFeaturesWithFakeAlignmentLogic
LOG(FATAL) << "Unexpected call to " << __func__;
}
int vector_register_count(const llvm::Function& function) const override {
LOG(FATAL) << "Unexpected call to " << __func__;
}
int64 minimum_alignment_for_allocation(int64 size_bytes) const override {
return fake_alignment_logic_(size_bytes);
}