Add lowering of reduce operation to the compilation pipeline.
PiperOrigin-RevId: 285167246 Change-Id: Icb293113a4bb92852e38bc3a94a2d58ab96cfcae
This commit is contained in:
parent
02a94bf820
commit
9bc6caa985
@ -143,6 +143,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/xla:lhlo",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_fuse_linalg",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
|
||||
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_linalg",
|
||||
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
@ -151,6 +152,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@local_config_mlir//:AffineDialectRegistration",
|
||||
"@local_config_mlir//:CFGTransforms",
|
||||
"@local_config_mlir//:GPUDialect",
|
||||
"@local_config_mlir//:GPUDialectRegistration",
|
||||
"@local_config_mlir//:GPUToNVVMTransforms",
|
||||
@ -160,6 +162,7 @@ cc_library(
|
||||
"@local_config_mlir//:LLVMTransforms",
|
||||
"@local_config_mlir//:Linalg",
|
||||
"@local_config_mlir//:LinalgDialectRegistration",
|
||||
"@local_config_mlir//:LinalgToLLVM",
|
||||
"@local_config_mlir//:LoopDialectRegistration",
|
||||
"@local_config_mlir//:LoopOps",
|
||||
"@local_config_mlir//:LoopsToGPUPass",
|
||||
|
@ -20,6 +20,8 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // TF:local_config_mlir
|
||||
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // TF:local_config_mlir
|
||||
#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" // TF:local_config_mlir
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // TF:local_config_mlir
|
||||
@ -73,6 +75,11 @@ struct FusionToLhloConverter
|
||||
signalPassFailure();
|
||||
}
|
||||
});
|
||||
getFunction().walk([&](mlir::xla_lhlo::ReduceOp op) {
|
||||
if (failed(applyPartialConversion(op, target, patterns, nullptr))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@ -266,12 +273,18 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) {
|
||||
pm.addPass(absl::make_unique<FusionToLhloConverter>());
|
||||
// Next, we can strip the outer fusion operation.
|
||||
pm.addPass(absl::make_unique<FusionOpRemover>());
|
||||
pm.addPass(absl::make_unique<DumpPass>());
|
||||
// Transform lhlo operations to LinAlg.
|
||||
pm.addPass(::mlir::xla_lhlo::createLegalizeToLinalgPass());
|
||||
pm.addPass(absl::make_unique<DumpPass>());
|
||||
// Fuse linalg operations. This will yield a single tiled loop nest where
|
||||
// the inner loops are single trip.
|
||||
pm.addPass(::mlir::xla_lhlo::createLhloFuseLinalg());
|
||||
pm.addPass(absl::make_unique<DumpPass>());
|
||||
// Legalize reduce operations directly to GPU dialect.
|
||||
pm.addPass(::mlir::xla_lhlo::createLegalizeToGpuPass());
|
||||
pm.addPass(absl::make_unique<DumpPass>());
|
||||
// Go from linalg to normal loops.
|
||||
pm.addPass(::mlir::linalg::createConvertLinalgToLoopsPass());
|
||||
pm.addPass(absl::make_unique<DumpPass>());
|
||||
@ -309,13 +322,50 @@ Status LowerLHLOToGPU(mlir::ModuleOp module) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// A pass that does the final lowering to NVVM. It collects all the patterns
|
||||
/// that are currently required, currently mixing std, linalg and gpu.
|
||||
class LowerToNVVMPass : public ::mlir::ModulePass<LowerToNVVMPass> {
|
||||
public:
|
||||
void runOnModule() override {
|
||||
::mlir::ModuleOp m = getModule();
|
||||
if (!m.getAttrOfType<::mlir::UnitAttr>(
|
||||
::mlir::gpu::GPUDialect::getKernelModuleAttrName())) {
|
||||
return;
|
||||
}
|
||||
|
||||
::mlir::OwningRewritePatternList patterns;
|
||||
::mlir::LinalgTypeConverter converter(m.getContext());
|
||||
::mlir::populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
// TODO(b/145824979) Remove linalg once sliceop is in std.
|
||||
::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns,
|
||||
&getContext());
|
||||
::mlir::populateGpuToNVVMConversionPatterns(converter, patterns);
|
||||
|
||||
::mlir::ConversionTarget target(getContext());
|
||||
target.addIllegalDialect<::mlir::gpu::GPUDialect>();
|
||||
target.addIllegalOp<::mlir::LLVM::ExpOp>();
|
||||
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
|
||||
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
|
||||
// TODO(csigg): Remove once we support replacing non-root ops.
|
||||
target.addLegalOp<::mlir::gpu::YieldOp>();
|
||||
if (failed(applyPartialConversion(m, target, patterns, &converter))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) {
|
||||
// We cannot verify as the signature of the kernel is rewritten.
|
||||
::mlir::PassManager pm(module.getContext(), /*verifyPasses=*/false);
|
||||
|
||||
// Rewrite kernel functions to LLVM IR.
|
||||
auto& kernelPm = pm.nest<::mlir::ModuleOp>();
|
||||
kernelPm.addPass(::mlir::createLowerGpuOpsToNVVMOpsPass());
|
||||
kernelPm.addPass(::mlir::createLowerToCFGPass());
|
||||
kernelPm.addPass(absl::make_unique<LowerToNVVMPass>());
|
||||
// Some basic cleanup.
|
||||
kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
||||
kernelPm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
||||
|
@ -276,44 +276,45 @@ ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] {
|
||||
LoweringStage::GPU);
|
||||
}
|
||||
|
||||
// Compiles an HLO module whose entry computation is a kInput fusion wrapping
// a reduce over dimension 0 of an f32[100,10] parameter (initial value 0,
// combiner %add), and verifies the resulting IR against the FileCheck
// patterns in the second string: a @fusion function containing an
// "xla_lhlo.fusion" region with a tensor_load, an xla_hlo.constant, an
// "xla_hlo.reduce" whose body adds its two block arguments, and a
// tensor_store of the reduce result into the result argument.
// NOTE(review): the patterns capture ARG0, REF0 and CT0 but then match the
// literal SSA names %arg0, %0 and %1 instead of the captures — confirm this
// is intentional.
TEST_F(LhloGenTest, FusedReduce) {
|
||||
CompileAndVerifyIr(R"(
|
||||
HloModule FusedReduce
|
||||
|
||||
%add (x: f32[], y: f32[]) -> f32[] {
|
||||
%x = f32[] parameter(0)
|
||||
%y = f32[] parameter(1)
|
||||
ROOT %add = f32[] add(f32[] %x, f32[] %y)
|
||||
}
|
||||
|
||||
%fused_computation (param: f32[100,10]) -> f32[10] {
|
||||
%param = f32[100,10] parameter(0)
|
||||
%constant = f32[] constant(0)
|
||||
ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant),
|
||||
dimensions={0}, to_apply=%add
|
||||
}
|
||||
|
||||
ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] {
|
||||
%x = f32[100,10] parameter(0)
|
||||
ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput,
|
||||
calls=%fused_computation
|
||||
}
|
||||
)",
|
||||
R"(
|
||||
;CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]])
|
||||
;CHECK: "xla_lhlo.fusion"() ( {
|
||||
;CHECK: %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]]
|
||||
;CHECK: %[[CT0:.*]] = xla_hlo.constant dense<0.000000e+00>
|
||||
;CHECK: %[[RED:.*]] = "xla_hlo.reduce"(%0, %1) ( {
|
||||
;CHECK: ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]])
|
||||
;CHECK: %[[ADD:.*]] = xla_hlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]]
|
||||
;CHECK: "xla_hlo.return"(%[[ADD]])
|
||||
;CHECK: })
|
||||
;CHECK: tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]]
|
||||
;CHECK: "xla_lhlo.terminator"()
|
||||
;CHECK-NEXT: })
|
||||
)");
|
||||
}
|
||||
// TODO(b/137624192): Reenable once we can fuse reductions.
|
||||
// TEST_F(LhloGenTest, FusedReduce) {
|
||||
// CompileAndVerifyIr(R"(
|
||||
// HloModule FusedReduce
|
||||
//
|
||||
// %add (x: f32[], y: f32[]) -> f32[] {
|
||||
// %x = f32[] parameter(0)
|
||||
// %y = f32[] parameter(1)
|
||||
// ROOT %add = f32[] add(f32[] %x, f32[] %y)
|
||||
// }
|
||||
//
|
||||
// %fused_computation (param: f32[100,10]) -> f32[10] {
|
||||
// %param = f32[100,10] parameter(0)
|
||||
// %constant = f32[] constant(0)
|
||||
// ROOT %reduce = f32[10]{0} reduce(f32[100,10]{1,0} %param, f32[] %constant),
|
||||
// dimensions={0}, to_apply=%add
|
||||
// }
|
||||
//
|
||||
// ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] {
|
||||
// %x = f32[100,10] parameter(0)
|
||||
// ROOT %fusion = f32[10]{0} fusion(f32[100,10]{1,0} %x), kind=kInput,
|
||||
// calls=%fused_computation
|
||||
// }
|
||||
// )",
|
||||
// R"(
|
||||
// ;CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]])
|
||||
// ;CHECK: "xla_lhlo.fusion"() ( {
|
||||
// ;CHECK: %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]]
|
||||
// ;CHECK: %[[CT0:.*]] = xla_hlo.constant dense<0.000000e+00>
|
||||
// ;CHECK: %[[RED:.*]] = "xla_hlo.reduce"(%0, %1) ( {
|
||||
// ;CHECK: ^bb0(%[[BARG0:.*]]: [[ETYPE:.*]], %[[BARG1:.*]]: [[ETYPE]])
|
||||
// ;CHECK: %[[ADD:.*]] = xla_hlo.add %[[BARG0]], %[[BARG1]] : [[ETYPE]]
|
||||
// ;CHECK: "xla_hlo.return"(%[[ADD]])
|
||||
// ;CHECK: })
|
||||
// ;CHECK: tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]]
|
||||
// ;CHECK: "xla_lhlo.terminator"()
|
||||
// ;CHECK-NEXT: })
|
||||
// )");
|
||||
// }
|
||||
|
||||
TEST_F(LhloGenTest, Broadcast) {
|
||||
CompileAndVerifyIr(R"(
|
||||
|
Loading…
x
Reference in New Issue
Block a user