Expose mlir_gpu passes to test tool and add a first test.

This just refactors how passes are registered but otherwise is NFC.

PiperOrigin-RevId: 333461139
Change-Id: Ice76b8337ba7beb0915fe28dff69ab7e83fbe0cb
This commit is contained in:
Stephan Herhut 2020-09-24 00:44:31 -07:00 committed by TensorFlower Gardener
parent 68467dc8a9
commit 1e8cb572f2
9 changed files with 274 additions and 25 deletions

View File

@ -73,8 +73,8 @@ tool_names = [
'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir',
'kernel-gen-opt', 'xla-thunks-opt', 'tfjs-opt'
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt',
'hlo_to_llvm_ir', 'kernel-gen-opt', 'xla-thunks-opt', 'tfjs-opt'
]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -1,6 +1,7 @@
# Description:
# MLIR-GPU-specific components in XLA service implementation.
load("//third_party/mlir:tblgen.bzl", "gentbl")
load(
"//tensorflow/core/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
@ -159,11 +160,20 @@ cc_library(
],
)
# Runs mlir-tblgen over passes.td and emits passes.h.inc, using the
# `-gen-pass-decls -name XlaMlirGpu` generator options.
gentbl(
    name = "passes_inc_gen",
    tbl_outs = [
        (
            "-gen-pass-decls -name XlaMlirGpu",
            "passes.h.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes.td",
    td_srcs = [
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)
cc_library(
name = "passes",
srcs = ["passes.cc"],
hdrs = ["passes.h"],
deps = [
":passes_inc_gen",
"//tensorflow/compiler/mlir/hlo:lhlo",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
@ -261,3 +271,20 @@ tf_cc_binary(
"@llvm-project//mlir:Support",
],
)
# Pass driver binary built from xla_mlir_gpu_opt.cc. Visibility is restricted
# to the mlir_gpu tests package and its subpackages.
tf_cc_binary(
    name = "xla-mlir-gpu-opt",
    srcs = [
        "xla_mlir_gpu_opt.cc",
    ],
    visibility = [
        "//tensorflow/compiler/xla/service/mlir_gpu/tests:__subpackages__",
    ],
    deps = [
        ":passes",
        "//tensorflow/compiler/mlir/hlo:all_passes",
        "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MlirOptLib",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
    ],
)

View File

@ -32,8 +32,10 @@ namespace xla {
namespace mlir_gpu {
namespace {
struct FusionOpRemoverPass
: public mlir::PassWrapper<FusionOpRemoverPass, ::mlir::FunctionPass> {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h.inc"
struct FusionOpRemoverPass : FusionOpRemoverPassBase<FusionOpRemoverPass> {
void runOnFunction() override {
getFunction().walk([&](mlir::lmhlo::FusionOp op) {
mlir::OpBuilder builder(op);
@ -52,8 +54,7 @@ struct FusionOpRemoverPass
}
};
struct StoreForwardingPass
: mlir::PassWrapper<StoreForwardingPass, mlir::FunctionPass> {
struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
mlir::StoreOp findStore(mlir::Operation* op,
std::function<bool(mlir::StoreOp)> matches) {
// Search from op upwards in the current block.
@ -132,7 +133,7 @@ struct StoreForwardingPass
};
struct DeadTempBufferRemovalPass
: mlir::PassWrapper<DeadTempBufferRemovalPass, ::mlir::FunctionPass> {
: DeadTempBufferRemovalPassBase<DeadTempBufferRemovalPass> {
bool operationConsideredDead(mlir::Operation* op) {
for (auto result : op->getResults()) {
if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) {
@ -183,8 +184,8 @@ struct DeadTempBufferRemovalPass
};
struct MoveScalarComputationsIntoGpuLaunchPass
: mlir::PassWrapper<MoveScalarComputationsIntoGpuLaunchPass,
mlir::FunctionPass> {
: MoveScalarComputationsIntoGpuLaunchPassBase<
MoveScalarComputationsIntoGpuLaunchPass> {
static bool isInliningBeneficiary(mlir::Operation* op) {
return llvm::isa<mlir::ConstantOp, mlir::DimOp, mlir::SelectOp,
mlir::CmpIOp>(op);
@ -234,14 +235,13 @@ struct MoveScalarComputationsIntoGpuLaunchPass
}
void runOnFunction() override {
mlir::FuncOp fun = getFunction();
fun.walk(
getFunction().walk(
[](mlir::gpu::LaunchOp launch) { inlineOperationsIntoLaunch(launch); });
}
};
struct RewriteKernelSignaturePass
: mlir::PassWrapper<RewriteKernelSignaturePass, mlir::FunctionPass> {
: RewriteKernelSignaturePassBase<RewriteKernelSignaturePass> {
void runOnFunction() override {
mlir::FuncOp func = getFunction();
mlir::ModuleOp module = func.getParentOfType<mlir::ModuleOp>();
@ -349,15 +349,14 @@ struct RewriteKernelSignaturePass
}
};
struct MapParallelLoopsPass
: public mlir::PassWrapper<MapParallelLoopsPass, mlir::FunctionPass> {
struct MapParallelLoopsPass : MapParallelLoopsPassBase<MapParallelLoopsPass> {
void runOnFunction() override {
mlir::greedilyMapParallelSCFToGPU(getFunction().getBody());
}
};
struct FuseInnerParallelLoopsPass
: public mlir::PassWrapper<FuseInnerParallelLoopsPass, mlir::FunctionPass> {
: FuseInnerParallelLoopsPassBase<FuseInnerParallelLoopsPass> {
void runOnFunction() override {
getFunction().walk([](mlir::scf::ParallelOp op) {
mlir::scf::naivelyFuseParallelOps(op.region());
@ -366,12 +365,10 @@ struct FuseInnerParallelLoopsPass
};
struct ParallelLoopCollapsingToFirstDimPass
: public mlir::PassWrapper<ParallelLoopCollapsingToFirstDimPass,
mlir::OperationPass<mlir::ModuleOp>> {
void runOnOperation() override {
mlir::Operation* module = getOperation();
module->walk([&](mlir::scf::ParallelOp op) {
: ParallelLoopCollapsingToFirstDimPassBase<
ParallelLoopCollapsingToFirstDimPass> {
void runOnFunction() override {
getFunction().walk([&](mlir::scf::ParallelOp op) {
unsigned num_loops = op.getNumLoops();
std::vector<unsigned> combinedLoops;
combinedLoops.reserve(num_loops);
@ -414,7 +411,7 @@ std::unique_ptr<mlir::FunctionPass> createMapParallelLoopsPass() {
return absl::make_unique<MapParallelLoopsPass>();
}
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
std::unique_ptr<mlir::FunctionPass>
createParallelLoopCollapsingToFirstDimPass() {
return absl::make_unique<ParallelLoopCollapsingToFirstDimPass>();
}

View File

@ -57,9 +57,12 @@ std::unique_ptr<mlir::FunctionPass> createFuseInnerParallelLoopsPass();
std::unique_ptr<mlir::FunctionPass> createMapParallelLoopsPass();
/// Collapses all loop dimension into the first one.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
std::unique_ptr<mlir::FunctionPass>
createParallelLoopCollapsingToFirstDimPass();
#define GEN_PASS_REGISTRATION
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h.inc"
} // namespace mlir_gpu
} // namespace xla

View File

@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_TD_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_TD_
include "mlir/Pass/PassBase.td"
// Declarative spec of the function pass selected on the command line as
// `mlir-gpu-fusion-op-remover`. Instances are created through
// createFusionOpRemoverPass().
def FusionOpRemoverPass : FunctionPass<"mlir-gpu-fusion-op-remover"> {
let summary = "Removes lhlo fusion ops by inlining their regions.";
let constructor = "createFusionOpRemoverPass()";
let description = [{
Replaces a FusionOp by the operations contained in its region.
}];
}
// Declarative spec of the function pass selected on the command line as
// `mlir-gpu-store-forwarding`. Instances are created through
// createStoreForwardingPass().
def StoreForwardingPass : FunctionPass<"mlir-gpu-store-forwarding"> {
let summary = "Limited pass to forward stores to loads.";
let constructor = "createStoreForwardingPass()";
let description = [{
Replaces a load that immediately follows a store to the same address with
the stored value.
}];
}
// Declarative spec of the function pass selected on the command line as
// `mlir-gpu-dead-temp-buffer-removal`. Instances are created through
// createDeadTempBufferRemovalPass().
def DeadTempBufferRemovalPass
: FunctionPass<"mlir-gpu-dead-temp-buffer-removal"> {
let summary = "Removal of dead temp buffers.";
let constructor = "createDeadTempBufferRemovalPass()";
let description = [{
Removes temporary buffers that are only written to but never read from or
that are read but the read value is not used. Needs an analysis that proves
that loads and stores are side-effect free (in bounds, no aliasing, etc.).
}];
}
// Declarative spec of the function pass selected on the command line as
// `mlir-gpu-inline-scalar-computation`. Instances are created through
// createMoveScalarComputationsIntoGpuLaunchPass().
def MoveScalarComputationsIntoGpuLaunchPass
    : FunctionPass<"mlir-gpu-inline-scalar-computation"> {
  let summary = "Pass to move scalar computations to the GPULaunchOp body.";
  let constructor = "createMoveScalarComputationsIntoGpuLaunchPass()";
  let description = [{
    Moves scalar computations to the GPULaunchOp body.
  }];
}
// Declarative spec of the function pass selected on the command line as
// `mlir-gpu-rewrite-signatures`. Instances are created through
// createRewriteKernelSignaturePass().
def RewriteKernelSignaturePass
: FunctionPass<"mlir-gpu-rewrite-signatures"> {
let summary = "Rewrite kernel signatures to be deterministic.";
let constructor = "createRewriteKernelSignaturePass()";
let description = [{
Sorts the operands to the kernel for a deterministic order. First operands
that are defined by function arguments, followed by operands that are
returned from the function. This only works for simple functions without
control flow and can be used in cases where the kernel is extracted and used
independently of the host-side code.
}];
}
// Declarative spec of the function pass selected on the command line as
// `mlir-gpu-map-parallel-loops`. Instances are created through
// createMapParallelLoopsPass().
def MapParallelLoopsPass
: FunctionPass<"mlir-gpu-map-parallel-loops"> {
let summary = "Greedily maps loops to GPU hardware dimensions.";
let constructor = "createMapParallelLoopsPass()";
let description = [{
Greedily maps loops to GPU hardware dimensions.
}];
}
// Declarative spec of the function pass selected on the command line as
// `mlir-gpu-fuse-inner-parallel-loops`. Instances are created through
// createFuseInnerParallelLoopsPass().
def FuseInnerParallelLoopsPass
    : FunctionPass<"mlir-gpu-fuse-inner-parallel-loops"> {
  // The summary is the one-line text shown for this pass; it must describe
  // loop fusion rather than store forwarding.
  let summary = "Fuses inner parallel loops.";
  let constructor = "createFuseInnerParallelLoopsPass()";
  let description = [{
    Directs parallel loop fusion to the inner loops. This cannot be done with
    a passmanager alone ATM, as nested pass managers require operations to
    be closed from above.
  }];
}
// Declarative spec of the function pass selected on the command line as
// `mlir-gpu-collapse-parallel-loops`. Instances are created through
// createParallelLoopCollapsingToFirstDimPass().
def ParallelLoopCollapsingToFirstDimPass
    : FunctionPass<"mlir-gpu-collapse-parallel-loops"> {
  let summary = "Collapses n-dimensional loops into one-dimensional ones.";
  let constructor = "createParallelLoopCollapsingToFirstDimPass()";
  let description = [{
    Collapses all loop dimensions of a parallel loop into the first one.
  }];
}
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_TD_

View File

@ -18,13 +18,16 @@ package_group(
)
glob_lit_tests(
data = [":test_utilities"],
data = [
":test_utilities",
"@llvm-project//mlir:run_lit.sh",
],
default_tags = tf_cuda_tests_tags() + [
"no_pip",
"config-cuda-only",
"no_rocm",
],
driver = "@llvm-project//mlir:run_lit.sh",
driver = "//tensorflow/compiler/mlir:run_lit.sh",
exclude = [
# TODO(b/137624192): Reenable once we can fuse reductions.
"fused_reduce.hlo",

View File

@ -0,0 +1,24 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
# Package-level defaults for this BUILD file.
package(licenses = ["notice"])  # Apache 2.0
# Creates lit tests from the files in this package that carry the `mlir`
# extension, driven by the run_lit.sh wrapper from //tensorflow/compiler/mlir.
glob_lit_tests(
    test_file_exts = ["mlir"],
    driver = "//tensorflow/compiler/mlir:run_lit.sh",
    data = [
        ":test_utilities",
        "@llvm-project//mlir:run_lit.sh",
    ],
)
# Groups the tools the tests invoke: the xla-mlir-gpu-opt driver and FileCheck.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        "//tensorflow/compiler/xla/service/mlir_gpu:xla-mlir-gpu-opt",
        "@llvm-project//llvm:FileCheck",
    ],
)

View File

@ -0,0 +1,60 @@
// RUN: xla-mlir-gpu-opt --mlir-gpu-dead-temp-buffer-removal %s | FileCheck %s
// Case: the buffer is only stored to and never read. The directives below
// expect that no allocation remains in the output of this function.
// CHECK-LABEL: @dead
func @dead() {
// CHECK-NOT: alloc
%0 = alloc() : memref<42xi32>
%c0 = constant 0 : i32
%c12 = constant 12 : index
store %c0, %0[%c12] : memref<42xi32>
return
}
// Case: the buffer is stored to and loaded from, but the loaded value has no
// users. The directives below expect that no allocation remains.
// CHECK-LABEL: @dead_load
func @dead_load() {
// CHECK-NOT: alloc
%0 = alloc() : memref<42xi32>
%c0 = constant 0 : i32
%c12 = constant 12 : index
store %c0, %0[%c12] : memref<42xi32>
%1 = load %0[%c12] : memref<42xi32>
return
}
// Case: the value loaded from the buffer is returned from the function, so
// the directives below expect the allocation to still be present.
// CHECK-LABEL: @used_load
func @used_load() -> i32 {
// CHECK: alloc
%0 = alloc() : memref<42xi32>
%c0 = constant 0 : i32
%c12 = constant 12 : index
store %c0, %0[%c12] : memref<42xi32>
%1 = load %0[%c12] : memref<42xi32>
return %1 : i32
}
// Case: a subview of the buffer is created but never used. The directives
// below expect that no allocation remains.
// CHECK-LABEL: @dead_subview
func @dead_subview() {
// CHECK-NOT: alloc
%0 = alloc() : memref<42xi32>
%c0 = constant 0 : i32
%c1 = constant 1 : index
%c4 = constant 4 : index
%c12 = constant 12 : index
store %c0, %0[%c12] : memref<42xi32>
%1 = subview %0[%c12][%c4][%c1] : memref<42xi32> to memref<?xi32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
return
}
// Case: a value is loaded through a subview of the buffer and returned, so
// the directives below expect the allocation to still be present.
// CHECK-LABEL: @used_subview
func @used_subview() -> i32 {
// CHECK: alloc
%0 = alloc() : memref<42xi32>
%c0 = constant 0 : i32
%c1 = constant 1 : index
%c4 = constant 4 : index
%c12 = constant 12 : index
store %c0, %0[%c12] : memref<42xi32>
%1 = subview %0[%c12][%c4][%c1] : memref<42xi32> to memref<?xi32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
%2 = load %1[%c1] : memref<?xi32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
return %2 : i32
}

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/InitAllDialects.h" // from @llvm-project
#include "mlir/InitAllPasses.h" // from @llvm-project
#include "mlir/Support/MlirOptMain.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"
int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses();
xla::mlir_gpu::registerXlaMlirGpuPasses();
mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);
mlir::mhlo::registerAllMhloDialects(registry);
return failed(mlir::MlirOptMain(
argc, argv, "XLA mlir gpu backend pass driver\n", registry));
}