From dc481995721dbcf557e7d7d0a193c50f1e91990b Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Tue, 3 Sep 2019 04:56:33 -0700 Subject: [PATCH] Add lowering pipeline from LHLO to NVVM. This adds a first simple lowering pipeline to the mlir_gpu compiler that takes an mlir module containing LHLO to a module with kernel functions containing LLVM/NVVM dialect only. PiperOrigin-RevId: 266903498 --- .../compiler/xla/service/mlir_gpu/BUILD | 27 ++++ .../xla/service/mlir_gpu/kernel_lowering.cc | 123 ++++++++++++++++++ .../xla/service/mlir_gpu/kernel_lowering.h | 32 +++++ .../xla/service/mlir_gpu/mlir_compiler.cc | 27 +++- .../xla/service/mlir_gpu/mlir_compiler.h | 4 +- .../service/mlir_gpu/mlir_irgen_test_base.cc | 8 +- .../service/mlir_gpu/mlir_irgen_test_base.h | 7 +- .../mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc | 43 ++++++ .../mlir_gpu/transforms/legalize_to_affine.h | 2 +- 9 files changed, 263 insertions(+), 10 deletions(-) create mode 100644 tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc create mode 100644 tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index a1cc01c2221..d461c30c620 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -36,7 +36,10 @@ cc_library( hdrs = ["mlir_compiler.h"], deps = [ ":failover_compiler", + ":kernel_lowering", ":lhlo_dialect_emitter", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:hlo", @@ -49,6 +52,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:target_constants", "//tensorflow/core:lib", "//tensorflow/stream_executor:stream_executor_headers", + "@local_config_mlir//:GPUDialect", "@local_config_mlir//:IR", "@local_config_mlir//:LLVMDialect", ], @@ -92,6 +96,29 @@ cc_library( ], ) +cc_library( + name = "kernel_lowering", + srcs = ["kernel_lowering.cc"], + hdrs = ["kernel_lowering.h"], + deps = [ + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/mlir_gpu/transforms:legalize_to_affine", + "@com_google_absl//absl/memory", + "@local_config_mlir//:GPUDialect", + "@local_config_mlir//:GPUToNVVMTransforms", + "@local_config_mlir//:GPUTransforms", + "@local_config_mlir//:IR", + "@local_config_mlir//:LLVMDialect", + "@local_config_mlir//:LLVMTransforms", + "@local_config_mlir//:LoopsToGPUPass", + "@local_config_mlir//:NVVMDialect", + "@local_config_mlir//:Pass", + "@local_config_mlir//:Transforms", + ], +) + cc_library( name = "mlir_irgen_test_base", testonly = True, diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc new file mode 100644 index 00000000000..236091e81e5 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -0,0 +1,123 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" + +#include + +#include "absl/memory/memory.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // TF:local_config_mlir +#include "mlir/Conversion/LoopsToGPU/LoopsToGPUPass.h" // TF:local_config_mlir +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // TF:local_config_mlir +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // TF:local_config_mlir +#include "mlir/Dialect/GPU/GPUDialect.h" // TF:local_config_mlir +#include "mlir/Dialect/GPU/Passes.h" // TF:local_config_mlir +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:local_config_mlir +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassManager.h" // TF:local_config_mlir +#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir +#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/service/mlir_gpu/transforms/legalize_to_affine.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace mlir_gpu { +namespace { + +using ::mlir::ConversionTarget; +using ::mlir::FuncOp; +using ::mlir::LLVMTypeConverter; +using ::mlir::ModulePass; +using ::mlir::ModulePassBase; +using ::mlir::OwningRewritePatternList; +using ::mlir::PassManager; +using ::mlir::gpu::GPUDialect; +using ::mlir::LLVM::LLVMDialect; +using ::mlir::NVVM::NVVMDialect; + +struct LowerKernelBodiesToNVVMPass + : public ModulePass { + public: + explicit LowerKernelBodiesToNVVMPass() = default; + + void runOnModule() override { + auto module = getModule(); + ConversionTarget target(*module.getContext()); + LLVMTypeConverter converter(module.getContext()); + + target.addLegalDialect(); + target.addLegalDialect(); + target.addDynamicallyLegalOp( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + + OwningRewritePatternList patterns; + populateStdToLLVMConversionPatterns(converter, patterns); + populateGpuToNVVMConversionPatterns(converter, patterns); + + module.walk([this, &target, &patterns, &converter](FuncOp function) { + if (!GPUDialect::isKernel(function)) { + return; + } + if (failed(applyFullConversion(function, target, patterns, &converter))) { + signalPassFailure(); + } + }); + } +}; + +} // namespace + +Status LowerLHLOToGPU(mlir::ModuleOp module) { + PassManager pm(module.getContext()); + + // Transform element-wise operations to Affine. + pm.addPass(createLegalizeAffinePass()); + // Transform affine to gpu launches. + // TODO(b/137624192) This pass requires known dimensions. Generalization it. + pm.addPass(::mlir::createSimpleLoopsToGPUPass(/*numBlockDims=*/0, + /*numThreadDims=*/2)); + // Take launches to launches with kernels. + pm.addPass(::mlir::createGpuKernelOutliningPass()); + // Some basic cleanup. + pm.addPass(::mlir::createCSEPass()); + + if (failed(pm.run(module))) { + return InternalError("Lowering to NVVM IR failed."); + } + return Status::OK(); +} + +Status LowerKernelBodiesToNVVM(mlir::ModuleOp module) { + // We cannot verify as the signature of the kernel is rewritten. + PassManager pm(module.getContext(), /*verifyPasses=*/false); + + // Rewrite kernel functions to LLVM IR. + pm.addPass(absl::make_unique()); + // Some basic cleanup. + pm.addPass(::mlir::createCSEPass()); + + if (failed(pm.run(module))) { + return InternalError("Lowering to NVVM IR failed."); + } + return Status::OK(); +} + +} // namespace mlir_gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h new file mode 100644 index 00000000000..1c1bc9eb722 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_KERNEL_LOWERING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_KERNEL_LOWERING_H_ + +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/status.h" + +namespace xla { +namespace mlir_gpu { + +Status LowerLHLOToGPU(mlir::ModuleOp module); + +Status LowerKernelBodiesToNVVM(mlir::ModuleOp module); + +} // namespace mlir_gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_KERNEL_LOWERING_H_ diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index d240003b039..eef17132efa 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -17,10 +17,13 @@ limitations under the License. #include +#include "mlir/Dialect/GPU/GPUDialect.h" // TF:local_config_mlir #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/Location.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" @@ -30,9 +33,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/target_constants.h" #include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" #include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -152,7 +158,22 @@ StatusOr> MlirCompiler::RunBackend( TF_RETURN_IF_ERROR( lhlo_emitter.EmitComputation(*module->entry_computation())); - if (module_hook_.callback && !module_hook_.apply_on_lowered) { + if (module_hook_.callback && + module_hook_.stage == IRHook::LoweringStage::LHLO) { + module_hook_.callback(*mlir_module); + } + + TF_RETURN_IF_ERROR(LowerLHLOToGPU(*mlir_module)); + + if (module_hook_.callback && + module_hook_.stage == IRHook::LoweringStage::GPU) { + module_hook_.callback(*mlir_module); + } + + TF_RETURN_IF_ERROR(LowerKernelBodiesToNVVM(*mlir_module)); + + if (module_hook_.callback && + module_hook_.stage == IRHook::LoweringStage::LLVM) { module_hook_.callback(*mlir_module); } @@ -194,7 +215,9 @@ void MlirCompiler::SetModuleHook(IRHook module_hook) { module_hook_ = module_hook; } -void MlirCompiler::RemoveModuleHook() { module_hook_ = {nullptr, false}; } +void MlirCompiler::RemoveModuleHook() { + module_hook_ = {nullptr, IRHook::LoweringStage::LHLO}; +} } // namespace mlir_gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h index fdc71903a06..b3d9c4a8085 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h @@ -57,8 +57,10 @@ class MlirCompiler : public Compiler { } struct IRHook { + enum class LoweringStage { LHLO, GPU, LLVM }; + std::function callback; - bool apply_on_lowered; + LoweringStage stage; }; void SetModuleHook(IRHook module_hook); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc index 4b6a03270c7..dfa3af8c39f 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc @@ -31,7 +31,7 @@ namespace mlir_gpu { void MlirIrGenTestBase::CompileAndVerifyIr( std::unique_ptr hlo_module, const string& pattern, - bool match_lowered_ir) { + LoweringStage stage) { MlirCompiler* compiler = GetMLIRCompiler(); string ir; compiler->SetModuleHook({[&ir](mlir::ModuleOp module) -> Status { @@ -42,7 +42,7 @@ void MlirIrGenTestBase::CompileAndVerifyIr( ir = buffer_string; return Status::OK(); }, - match_lowered_ir}); + stage}); Status status = CompileToExecutable(std::move(hlo_module)).status(); compiler->RemoveModuleHook(); TF_ASSERT_OK(status); @@ -54,12 +54,12 @@ void MlirIrGenTestBase::CompileAndVerifyIr( void MlirIrGenTestBase::CompileAndVerifyIr(const string& hlo_text, const string& expected_llvm_ir, - bool match_lowered_ir) { + LoweringStage stage) { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule(hlo_text, config)); - CompileAndVerifyIr(std::move(module), expected_llvm_ir, match_lowered_ir); + CompileAndVerifyIr(std::move(module), expected_llvm_ir, stage); } MlirCompiler* MlirIrGenTestBase::GetMLIRCompiler() { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h index 613ddc27bf6..c4990ab50d3 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h @@ -27,6 +27,8 @@ namespace mlir_gpu { // Tests that verify IR emitted by the CPU/GPU backend is as expected. class MlirIrGenTestBase : public CodegenTestBase { protected: + using LoweringStage = MlirCompiler::IRHook::LoweringStage; + // Compiles the given HLO module to MLIR IR and verifies the IR matches the // given pattern. `pattern` is in the FileCheck pattern matching syntax // (http://llvm.org/docs/CommandGuide/FileCheck.html). @@ -37,13 +39,14 @@ class MlirIrGenTestBase : public CodegenTestBase { // steps to LLVM IR are applied; otherwise, the IR before lowering is // matched. void CompileAndVerifyIr(std::unique_ptr hlo_module, - const string& pattern, bool match_lowered_ir = false); + const string& pattern, + LoweringStage stage = LoweringStage::LHLO); // A thin wrapper around CompileAndVerifyIr that parses `hlo_text` to create // an HLO module. void CompileAndVerifyIr(const string& hlo_text, const string& expected_llvm_ir, - bool match_lowered_ir = false); + LoweringStage stage = LoweringStage::LHLO); // Compiles and returns module with optimizations from a given HLO. StatusOr> GetOptimizedModule( diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc index 657f9dbe061..5eb559b7813 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc @@ -47,6 +47,49 @@ ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { )"); } +TEST_F(LhloGenTest, AddInGPUDialect) { + CompileAndVerifyIr(R"( +HloModule Add + +ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) +})", + R"( +;CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { +;CHECK: "gpu.launch_func"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[ARG0]], %[[ARG1]], %[[ARG2]] +;CHECK: } +;CHECK: func @add_kernel(%[[ARG0]]: [[TYPE]], %[[ARG1]]: [[TYPE]], %[[ARG2]]: [[TYPE]] +;CHECK: load %[[ARG0]][[INDEX:.*]] +;CHECK: load %[[ARG1]][[INDEX]] +;CHECK: store %{{.*}}, %[[ARG2]][[INDEX]] + )", + LoweringStage::GPU); +} + +TEST_F(LhloGenTest, AddInLVVMDialect) { + CompileAndVerifyIr(R"( +HloModule Add + +ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) +})", + R"( +;CHECK: func @add_kernel(%[[ARG0:.*]]: [[TYPE:!llvm<.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]] +;CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ARG0]][[INDEX:.*]] +;CHECK: %[[VAL0:.*]] = llvm.load %[[GEP0]] +;CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[ARG1]][[INDEX]] +;CHECK: %[[VAL1:.*]] = llvm.load %[[GEP1]] +;CHECK: %[[VAL2:.*]] = llvm.fadd %[[VAL0]], %[[VAL1]] +;CHECK: %[[GEP2:.*]] = llvm.getelementptr %[[ARG2]][[INDEX]] +;CHECK: llvm.store %[[VAL2]], %[[GEP2]] + )", + LoweringStage::LLVM); +} + TEST_F(LhloGenTest, AddMultiply) { CompileAndVerifyIr(R"( HloModule AddMultiply diff --git a/tensorflow/compiler/xla/service/mlir_gpu/transforms/legalize_to_affine.h b/tensorflow/compiler/xla/service/mlir_gpu/transforms/legalize_to_affine.h index aee4fe37a3d..419bca4ad00 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/transforms/legalize_to_affine.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/transforms/legalize_to_affine.h @@ -24,7 +24,7 @@ namespace xla { namespace mlir_gpu { // Lowers from LHLO dialect to affine dialect. -std::unique_ptr<::mlir::FunctionPassBase> createLegalizeAffine(); +std::unique_ptr<::mlir::FunctionPassBase> createLegalizeAffinePass(); // Adds patterns to convert LHLO binary ops to affine loops. void AppendBinaryOpsPatterns(::mlir::MLIRContext* context,