[XLA:CPU] Teach dot_op_emitter how to tile & vectorize linalg matmuls

And turn them on by default. This is on par with the existing emitter, sometimes better, and it unlocks more potential. The strategy classes are duplicated right now, but I expect them to graduate to MLIR core soon. I'm planning to remove the custom LLVM IR emitters if this turns out to be stable enough.

PiperOrigin-RevId: 320117625
Change-Id: I3580df9990ca2a022a49327fa819c2086fd1e2ed
parent 82e12bf387
commit 3dda4182aa
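The change routes linalg.matmul through the duplicated MatmulCodegenStrategy: the emitter builds the matmul with the EDSC intrinsics and then tiles, promotes, and vectorizes it over the generated function. The condensed sketch below summarizes that flow; the tile sizes and alignment are illustrative placeholders (the real values come from TargetMachineFeatures, as in the EmitLinalgMatmul hunk further down), and `matmul_func` stands in for the emitted FuncOp.

// Condensed sketch of the new lowering flow, not a verbatim copy of the diff.
// Tile sizes and alignment below are illustrative assumptions.
mlir::linalg::LinalgTilingOptions tiling_options;
tiling_options = tiling_options.setTileSizes({4, 16, 8});  // {m, n, k} guess.

xla::cpu::mlir_strategy::MatmulCodegenStrategy strategy;
strategy.tile<mlir::linalg::MatmulOp>(tiling_options)
    .promote<mlir::linalg::MatmulOp>(mlir::linalg::LinalgPromotionOptions()
                                         .setAlignment(16)
                                         .setUseFullTileBuffersByDefault(true)
                                         .setUseAlloca(true))
    .vectorize<mlir::linalg::MatmulOp>()
    .setVectorTransformsOptions(
        mlir::vector::VectorTransformsOptions().setVectorTransformsOptions(
            mlir::vector::VectorContractLowering::OuterProduct))
    .setVectorTransferToSCFOptions(
        mlir::VectorTransferToSCFOptions().setUnroll(true));
strategy.transform(matmul_func);  // Tiles, promotes and vectorizes the matmul.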
@@ -471,7 +471,6 @@ cc_library(
        ":cpu_runtime",
        ":ir_emission_utils",
        ":mlir_emitter",
        ":mlir_matmul_codegen_strategy",
        ":target_machine_features",
        ":tiled_dot_emitter",
        ":vector_support_library",
@@ -1103,33 +1102,12 @@ cc_library(
        "@llvm-project//llvm:Core",
        "@llvm-project//llvm:IPO",
        "@llvm-project//llvm:Linker",
        "@llvm-project//mlir:CFGTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LLVMTransforms",
        "@llvm-project//mlir:LinalgToLLVM",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:TargetLLVMIR",
        "@llvm-project//mlir:Transforms",
        "@llvm-project//mlir:VectorToLLVM",
    ],
)

cc_library(
    name = "mlir_matmul_codegen_strategy",
    srcs = ["mlir_matmul_codegen_strategy.cc"],
    hdrs = ["mlir_matmul_codegen_strategy.h"],
    deps = [
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Affine",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LinalgOps",
        "@llvm-project//mlir:LinalgTransforms",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
        "@llvm-project//mlir:VectorOps",
        "@llvm-project//mlir:VectorToSCF",
    ],
)
@@ -25,6 +25,7 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
const char* const kXlaForceEnableExperimentalLlvmIrGemm =
    "xla_force_enable_experimental_llvm_ir_gemm";
const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot";
const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size";

}  // namespace
@@ -63,6 +64,12 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
  return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0;
}

bool UseLinalgForDot(const HloModuleConfig& config) {
  const auto& extra_options_map =
      config.debug_options().xla_backend_extra_options();
  return extra_options_map.count(kXlaUseLinalgForDot) > 0;
}

static absl::string_view RemoveSuffix(absl::string_view str,
                                      absl::string_view suffix) {
  CHECK_GE(str.size(), suffix.size());
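The new xla_use_linalg_for_dot option is keyed purely off the presence of the entry in the backend extra-options map, so the linalg path can be enabled per module. A hypothetical opt-in sketch follows (the exact plumbing of DebugOptions into the HloModuleConfig depends on the surrounding pipeline); the same key should also be settable through the --xla_backend_extra_options flag, which accepts bare keys without a value.

// Hypothetical sketch: enable the linalg-based dot emitter for one module.
xla::DebugOptions debug_options = xla::GetDebugOptionsFromFlags();
(*debug_options.mutable_xla_backend_extra_options())["xla_use_linalg_for_dot"] =
    "";  // Only the key's presence matters; the value is ignored.
xla::HloModuleConfig config;
config.set_debug_options(debug_options);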
@@ -31,12 +31,10 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
@@ -204,20 +202,6 @@ class DotOpEmitter {
        .value_or(kDefaultTileSize);
  }

  std::array<int64_t, 3> GetMlirGemmTileSize() const {
    // Tile by 4 x registers x register size. This was picked by running
    // small matmuls on Haswell and Skylake. There's a lot of room for
    // improvement here.
    constexpr int64_t kDefaultTileSizeForM = 4;
    int64_t elements_per_register =
        target_machine_features_.vector_register_num_elements(
            *b_->GetInsertBlock()->getParent(),
            dot_info_.result_shape.element_type());
    int64_t num_registers = target_machine_features_.vector_register_count(
        *b_->GetInsertBlock()->getParent());
    return {{kDefaultTileSizeForM, num_registers, elements_per_register}};
  }

  DotInfo dot_info_;
  string dot_hlo_name_;
  const llvm_ir::IrArray& target_array_;
@@ -266,7 +250,6 @@ Status DotOpEmitter::EmitLinalgMatmul() {
      absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
                   dot_info_.lhs_shape.ToString(true), "_",
                   dot_info_.rhs_shape.ToString(true));

  return EmitMlirFuncAndCall(
      mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
      operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
@@ -276,27 +259,6 @@ Status DotOpEmitter::EmitLinalgMatmul() {
        mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{},
                                              mlir::ValueRange{b, c, a});
        mlir::edsc::intrinsics::std_ret();

        mlir::linalg::LinalgTilingOptions tilingOptions;
        tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize());
        int64 alignment =
            target_machine_features_.minimum_alignment_for_allocation(
                ShapeUtil::ByteSizeOf(dot_info_.result_shape));
        mlir_strategy::MatmulCodegenStrategy strategy;
        strategy.tile<mlir::linalg::MatmulOp>(tilingOptions)
            .promote<mlir::linalg::MatmulOp>(
                mlir::linalg::LinalgPromotionOptions()
                    .setAlignment(alignment)
                    .setUseFullTileBuffersByDefault(true)
                    .setUseAlloca(true))
            .vectorize<mlir::linalg::MatmulOp>()
            .setVectorTransformsOptions(
                mlir::vector::VectorTransformsOptions()
                    .setVectorTransformsOptions(
                        mlir::vector::VectorContractLowering::OuterProduct))
            .setVectorTransferToSCFOptions(
                mlir::VectorTransferToSCFOptions().setUnroll(true));
        strategy.transform(function);
      });
}

@@ -986,8 +948,7 @@ DotImplementationStrategy GetDotImplementationStrategy(

  if (IsAlignedGemm(dot_info, target_machine_features)) {
    if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
      return primitive_util::IsFloatingPointType(
                 dot_info.result_shape.element_type())
      return options::UseLinalgForDot(config)
                 ? DotImplementationStrategy::kLinalgMatmul
                 : DotImplementationStrategy::kTiledLlvmIrGemm;
    }

@@ -17,14 +17,14 @@ limitations under the License.

#include "llvm/Linker/Linker.h"
#include "llvm/Transforms/IPO/Internalize.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"  // from @llvm-project
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"  // from @llvm-project
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Target/LLVMIR.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"

namespace xla {
@@ -35,9 +35,9 @@ namespace {
std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module) {
  mlir::PassManager manager(module->getContext());
  manager.addPass(mlir::createConvertLinalgToLoopsPass());
  manager.addPass(mlir::createLowerAffinePass());
  manager.addPass(mlir::createLowerToCFGPass());
  manager.addPass(mlir::createConvertLinalgToLLVMPass());
  manager.addPass(mlir::createConvertVectorToLLVMPass());
  manager.addPass(mlir::createLowerToLLVMPass());
  CHECK(succeeded(manager.run(*module)));
  return mlir::translateModuleToLLVMIR(*module);
}

@@ -1,269 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"  // from @llvm-project
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"  // from @llvm-project
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Utils/Utils.h"  // from @llvm-project
#include "mlir/Dialect/SCF/SCF.h"  // from @llvm-project
#include "mlir/Dialect/SCF/Utils.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"  // from @llvm-project
#include "mlir/Dialect/Vector/VectorOps.h"  // from @llvm-project
#include "mlir/Dialect/Vector/VectorTransforms.h"  // from @llvm-project
#include "mlir/IR/AffineExpr.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Dominance.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/LoopUtils.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project

// TODO(kramerb): Remove this once strategy is in mlir core.

using namespace mlir;          // NOLINT
using namespace mlir::linalg;  // NOLINT

#define DEBUG_TYPE "matmul-codegen-strategy"

namespace xla {
namespace cpu {
namespace mlir_strategy {

//===----------------------------------------------------------------------===//
// TODO: Cleanup and upstream these to go into core. Please ignore for now !
//===----------------------------------------------------------------------===//
static void hoistRedundantCopies(FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](linalg::FillOp op) {
      auto loop = op.getParentOfType<scf::ForOp>();
      if (!loop) return;

      for (auto operand : op.getOperands())
        if (!loop.isDefinedOutsideOfLoop(operand)) return;

      // Hoist fill before.
      op.getOperation()->moveBefore(loop);
      changed = true;
    });

    func.walk([&](linalg::CopyOp op) {
      auto loop = op.getParentOfType<scf::ForOp>();
      if (!loop) return;

      for (auto operand : op.getOperands())
        if (!loop.isDefinedOutsideOfLoop(operand)) return;

      Value sourceView = op.getInput(0);
      while (auto subViewOp = sourceView.getDefiningOp<SubViewOp>())
        sourceView = subViewOp.getViewSource();

      // Source traces back to a block argument.
      if (sourceView.isa<BlockArgument>()) {
        op.getOperation()->moveBefore(loop);
      } else {
        assert(sourceView.getDefiningOp<ViewOp>() ||
               sourceView.getDefiningOp<AllocOp>() ||
               sourceView.getDefiningOp<AllocaOp>());
        op.getOperation()->moveAfter(loop);
      }
      changed = true;
    });
  }
}

/// Substitute scf.for = %lb to %ub step %step by an AffineExpr expressing:
///   `%lb + %step * new_dim` where
///   1. the AffineExpr for %lb is either an AffineConstantExpr or an
///      AffineDimExpr depending on whether the value is constant or not.
///   2. the AffineExpr for %step is either an AffineConstantExpr or an
///      AffineSymbolExpr depending on whether the value is constant or not.
///
static void substitute(scf::ForOp forOp, SmallVectorImpl<AffineExpr> &exprs,
                       SmallVectorImpl<Value> &dims,
                       SmallVectorImpl<Value> &symbols) {
  MLIRContext *ctx = forOp.getContext();
  auto lbConstant = forOp.lowerBound().getDefiningOp<ConstantIndexOp>();
  AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx)
                             : getAffineDimExpr(dims.size(), ctx);

  auto stepConstant = forOp.step().getDefiningOp<ConstantIndexOp>();
  AffineExpr step = stepConstant
                        ? getAffineConstantExpr(stepConstant.getValue(), ctx)
                        : getAffineSymbolExpr(symbols.size(), ctx);

  if (!lbConstant) dims.push_back(forOp.lowerBound());
  if (!stepConstant) symbols.push_back(forOp.step());
  exprs.push_back(lb + step * getAffineDimExpr(dims.size(), ctx));

  auto ubConstant = forOp.upperBound().getDefiningOp<ConstantIndexOp>();
  AffineExpr ub = ubConstant ? getAffineConstantExpr(ubConstant.getValue(), ctx)
                             : getAffineDimExpr(dims.size(), ctx);
  if (!ubConstant) dims.push_back(forOp.upperBound());
  exprs.push_back(ub);

  dims.push_back(forOp.getInductionVar());
}

/// Traverse the .
static void substitute(AffineMinOp minOp, SmallVectorImpl<AffineExpr> &exprs,
                       SmallVectorImpl<Value> &dims,
                       SmallVectorImpl<Value> &symbols) {
  MLIRContext *ctx = minOp.getContext();
  for (Value v : minOp.getDimOperands()) {
    if (auto forOp = scf::getForInductionVarOwner(v)) {
      substitute(forOp, exprs, dims, symbols);
      continue;
    }
    if (auto parentMinOp = v.getDefiningOp<AffineMinOp>()) {
      substitute(parentMinOp, exprs, dims, symbols);
      continue;
    }
    exprs.push_back(getAffineDimExpr(dims.size(), ctx));
    dims.push_back(v);
  }
}

/// Perform folding of chains of AffineMinOp.
struct AffineMinCanonicalizationPattern : public OpRewritePattern<AffineMinOp> {
  using OpRewritePattern<AffineMinOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineMinOp minOp,
                                PatternRewriter &rewriter) const override;
};

LogicalResult AffineMinCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(llvm::dbgs() << "\nCanonicalize AffineMin: "
                          << *minOp.getOperation() << "\n");

  int64_t min = std::numeric_limits<int64_t>::max();
  for (auto e : minOp.map().getResults())
    if (auto cstExpr = e.dyn_cast<AffineConstantExpr>())
      min = std::min(min, cstExpr.getValue());
  if (min == std::numeric_limits<int64_t>::max()) return failure();

  SmallVector<AffineExpr, 4> exprs;
  SmallVector<Value, 4> dims, symbols;
  substitute(minOp, exprs, dims, symbols);

  SmallVector<Value, 4> operands = dims;
  operands.append(symbols.begin(), symbols.end());

  MLIRContext *ctx = minOp.getContext();
  auto map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx);
  LLVM_DEBUG(llvm::dbgs() << "Substitution map: " << map << "\n");

  SmallVector<AffineExpr, 4> modExprs;
  for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx)
    modExprs.push_back(getAffineDimExpr(idx, ctx) % min);
  map = AffineMap::get(map.getNumResults(), 0, modExprs, ctx).compose(map);
  canonicalizeMapAndOperands(&map, &operands);
  map = simplifyAffineMap(map);

  LLVM_DEBUG(llvm::dbgs() << "Post mod: " << map << "\n";
             llvm::interleaveComma(operands, llvm::dbgs()));

  if (!llvm::all_of(map.getResults(), [](AffineExpr e) {
        if (auto cst = e.dyn_cast<AffineConstantExpr>())
          return cst.getValue() == 0;
        return false;
      }))
    return failure();

  rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, min);
  return success();
}
//===----------------------------------------------------------------------===//
// END TODO
//===----------------------------------------------------------------------===//

void MatmulCodegenStrategy::transform(FuncOp func) const {
  MLIRContext *context = func.getContext();
  // Emplace patterns one at a time while also maintaining a simple chained
  // state transition.
  unsigned stepCount = 0;
  SmallVector<OwningRewritePatternList, 4> stage1Patterns;
  auto zeroState = Identifier::get(std::to_string(stepCount), context);
  auto currentState = zeroState;
  for (auto &t : transformation_sequence) {
    auto nextState = Identifier::get(std::to_string(++stepCount), context);
    auto marker = (currentState == zeroState)
                      ? linalg::LinalgMarker({}, nextState)
                      : linalg::LinalgMarker(currentState, nextState);
    stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
    currentState = nextState;
  }

  OwningRewritePatternList stage2Patterns =
      linalg::getLinalgTilingCanonicalizationPatterns(context);
  stage2Patterns.insert<AffineMinCanonicalizationPattern>(context);

  auto stage3Transforms = [](Operation *op) {
    // Some of these may be too aggressive as a stage 3 that is applied on each
    // stage 1 application and may have to be split out to post staged patterns
    // application (in which case they could just be passes, TBD).
    PassManager pm(op->getContext());
    pm.addPass(createLoopInvariantCodeMotionPass());
    if (failed(pm.run(op->getParentOfType<ModuleOp>())))
      llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
    promoteSingleIterationLoops(cast<FuncOp>(op));
    hoistViewAllocOps(cast<FuncOp>(op));
    hoistRedundantVectorTransfers(cast<FuncOp>(op));
    hoistRedundantCopies(cast<FuncOp>(op));
    return success();
  };
  linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
                              stage3Transforms);

  //===--------------------------------------------------------------------===//
  // Post staged patterns transforms
  //===--------------------------------------------------------------------===//
  // Programmatic controlled lowering of vector.contract only.
  OwningRewritePatternList vectorContractLoweringPatterns;
  vectorContractLoweringPatterns
      .insert<ContractionOpToOuterProductOpLowering,
              ContractionOpToMatmulOpLowering, ContractionOpLowering>(
          vector_transforms_options, context);
  applyPatternsAndFoldGreedily(func, vectorContractLoweringPatterns);

  // Programmatic controlled lowering of vector.transfer only.
  OwningRewritePatternList vectorToLoopsPatterns;
  populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
                                        vector_to_scf_options);
  applyPatternsAndFoldGreedily(func, vectorToLoopsPatterns);
}

}  // namespace mlir_strategy
}  // namespace cpu
}  // namespace xla
@@ -1,188 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_
#define MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"  // from @llvm-project
#include "mlir/Dialect/Vector/VectorOps.h"  // from @llvm-project
#include "mlir/Dialect/Vector/VectorTransforms.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project

// TODO(kramerb): Remove this once strategy is in mlir core.

namespace xla {
namespace cpu {
namespace mlir_strategy {

/// Abstract Transformation class applied in a sequence that also handles state
/// through markers.
struct Transformation {
  virtual ~Transformation() = default;
  virtual mlir::OwningRewritePatternList buildRewritePatterns(
      mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) = 0;
  mlir::linalg::LinalgMarker marker;
};

/// Tiling transformation enqueues a particular stage-1 pattern for
/// `Tile<LinalgOpType>` with the appropriate `options`.
// TODO: variadic LinalgOpTypes.
template <typename LinalgOpType>
struct Tile : public Transformation {
  explicit Tile(mlir::linalg::LinalgTilingOptions options) : options(options) {}

  mlir::OwningRewritePatternList buildRewritePatterns(
      mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
    mlir::OwningRewritePatternList tiling_patterns;
    tiling_patterns.insert<mlir::linalg::LinalgTilingPattern<LinalgOpType>>(
        context, options, m);
    return tiling_patterns;
  }

 private:
  mlir::linalg::LinalgTilingOptions options;
};

/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Promote<LinalgOpType>` with the appropriate `options`.
// TODO: variadic LinalgOpTypes.
template <typename LinalgOpType>
struct Promote : public Transformation {
  explicit Promote(mlir::linalg::LinalgPromotionOptions options)
      : options(options) {}

  mlir::OwningRewritePatternList buildRewritePatterns(
      mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
    mlir::OwningRewritePatternList promotion_patterns;
    promotion_patterns
        .insert<mlir::linalg::LinalgPromotionPattern<LinalgOpType>>(context,
                                                                    options, m);
    return promotion_patterns;
  }

 private:
  mlir::linalg::LinalgPromotionOptions options;
};

/// Vectorization transformation enqueues a particular stage-1 pattern for
/// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
/// transfer rewrite forwarding patterns.
// TODO: variadic LinalgOpTypes.
template <typename LinalgOpType>
struct Vectorize : public Transformation {
  mlir::OwningRewritePatternList buildRewritePatterns(
      mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override {
    mlir::OwningRewritePatternList vectorization_patterns;
    // FillOp may interfere with forwarding patterns atm, so we bump up the
    // priority of LinalgCopyVTRForwardingPattern /
    // LinalgCopyVTWForwardingPattern.
    vectorization_patterns
        .insert<mlir::linalg::LinalgVectorizationPattern<LinalgOpType>>(context,
                                                                        m);
    vectorization_patterns.insert<mlir::linalg::LinalgCopyVTRForwardingPattern,
                                  mlir::linalg::LinalgCopyVTWForwardingPattern>(
        context,
        /*benefit=*/2);
    return vectorization_patterns;
  }
};

/// Matmul-specific strategy object controls how a linalg.matmul is
/// progressively lowered.
/// The strategy uses a 3-level staged patterns strategy which allows ordering
/// transformations by using the Linalg `applyStagedPatterns` function, where:
///   1. The first stage consists of the successive `tile`, `promote` and
///      `vectorize` patterns, applied sequentially.
///   2. The second stage consists of common local canonicalization patterns
///      that are applied eagerly after each stage-1 pattern.
///   3. the third stage consists of more global transformation, also applied
///      eagerly, after all stage-2 patterns. Such more global transformations
struct MatmulCodegenStrategy {
  /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling
  /// `options`.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &tile(mlir::linalg::LinalgTilingOptions options) {
    transformation_sequence.emplace_back(new Tile<LinalgOpType>(options));
    return *this;
  }
  /// Conditionally append a pattern to add a level of tiling for
  /// `LinalgOpType` with tiling `options`.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &tileIf(bool b,
                                mlir::linalg::LinalgTilingOptions options) {
    return b ? tile<LinalgOpType>(options) : *this;
  }
  /// Append a pattern to add a level of promotion for `LinalgOpType` with
  /// promotion `options`.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &promote(mlir::linalg::LinalgPromotionOptions options) {
    transformation_sequence.emplace_back(new Promote<LinalgOpType>(options));
    return *this;
  }
  /// Conditionally append a pattern to add a level of promotion for
  /// `LinalgOpType` with promotion `options`.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &promoteIf(
      bool b, mlir::linalg::LinalgPromotionOptions options) {
    return b ? promote<LinalgOpType>(options) : *this;
    return *this;
  }
  /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &vectorize() {
    transformation_sequence.emplace_back(new Vectorize<LinalgOpType>());
    return *this;
  }
  /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
  /// operation.
  template <typename LinalgOpType>
  MatmulCodegenStrategy &vectorizeIf(bool b) {
    return b ? vectorize<LinalgOpType>() : *this;
    return *this;
  }
  /// Configure the post staged-patterns late vector transformations.
  MatmulCodegenStrategy &setVectorTransformsOptions(
      mlir::vector::VectorTransformsOptions options) {
    vector_transforms_options = options;
    return *this;
  }
  /// Configure the post staged-patterns late vector.transfer to scf
  /// conversion.
  MatmulCodegenStrategy &setVectorTransferToSCFOptions(
      mlir::VectorTransferToSCFOptions options) {
    vector_to_scf_options = options;
    return *this;
  }

  /// Apply the transformation patterns in sequence with cleanup
  /// transformations interleaved.
  void transform(mlir::FuncOp func) const;

 private:
  mlir::LogicalResult postPatternTransforms(mlir::Operation *func) const;

  mlir::vector::VectorTransformsOptions vector_transforms_options;
  mlir::VectorTransferToSCFOptions vector_to_scf_options;
  llvm::SmallVector<std::unique_ptr<Transformation>, 4> transformation_sequence;
};

}  // namespace mlir_strategy
}  // namespace cpu
}  // namespace xla

#endif  // MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_
@@ -52,12 +52,6 @@ class TargetMachineFeatures {
  virtual int vector_register_num_elements(const llvm::Function& function,
                                           PrimitiveType type) const = 0;

  // Return the number of vector registers. We need to pass in
  // "function" since llvm functions can contain annotations for specializing
  // them to specific micro-architectures (though currently XLA does not use
  // this functionality).
  virtual int vector_register_count(const llvm::Function& function) const = 0;

  // Returns the minimum alignment for a buffer of size size_bytes.
  virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0;

@@ -90,12 +84,6 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures {
           (primitive_util::BitWidth(type) / 8);
  }

  int vector_register_count(const llvm::Function& function) const override {
    llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function);
    return static_cast<int>(tti->getNumberOfRegisters(
        tti->getRegisterClassForType(/*Vector=*/true)));
  }

  int64 minimum_alignment_for_allocation(int64 size_bytes) const override;

 private:
@@ -44,10 +44,6 @@ class TargetMachineFeaturesWithFakeAlignmentLogic
    LOG(FATAL) << "Unexpected call to " << __func__;
  }

  int vector_register_count(const llvm::Function& function) const override {
    LOG(FATAL) << "Unexpected call to " << __func__;
  }

  int64 minimum_alignment_for_allocation(int64 size_bytes) const override {
    return fake_alignment_logic_(size_bytes);
  }