Expose mlir_gpu passes to test tool and add a first test.
This just refactors how passes are registered but otherwise is NFC. PiperOrigin-RevId: 333461139 Change-Id: Ice76b8337ba7beb0915fe28dff69ab7e83fbe0cb
This commit is contained in:
parent
68467dc8a9
commit
1e8cb572f2
@ -73,8 +73,8 @@ tool_names = [
|
||||
'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
|
||||
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
|
||||
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
|
||||
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir',
|
||||
'kernel-gen-opt', 'xla-thunks-opt', 'tfjs-opt'
|
||||
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-mlir-gpu-opt', 'xla-opt',
|
||||
'hlo_to_llvm_ir', 'kernel-gen-opt', 'xla-thunks-opt', 'tfjs-opt'
|
||||
]
|
||||
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Description:
|
||||
# MLIR-GPU-specific components in XLA service implementation.
|
||||
|
||||
load("//third_party/mlir:tblgen.bzl", "gentbl")
|
||||
load(
|
||||
"//tensorflow/core/platform/default:cuda_build_defs.bzl",
|
||||
"if_cuda_is_configured",
|
||||
@ -159,11 +160,20 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
gentbl(
|
||||
name = "passes_inc_gen",
|
||||
tbl_outs = [("-gen-pass-decls -name XlaMlirGpu", "passes.h.inc")],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "passes.td",
|
||||
td_srcs = ["@llvm-project//mlir:PassBaseTdFiles"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "passes",
|
||||
srcs = ["passes.cc"],
|
||||
hdrs = ["passes.h"],
|
||||
deps = [
|
||||
":passes_inc_gen",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//llvm:Support",
|
||||
@ -261,3 +271,20 @@ tf_cc_binary(
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "xla-mlir-gpu-opt",
|
||||
srcs = ["xla_mlir_gpu_opt.cc"],
|
||||
visibility = ["//tensorflow/compiler/xla/service/mlir_gpu/tests:__subpackages__"],
|
||||
deps = [
|
||||
":passes",
|
||||
"//tensorflow/compiler/mlir/hlo:all_passes",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:MlirOptLib",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
@ -32,8 +32,10 @@ namespace xla {
|
||||
namespace mlir_gpu {
|
||||
namespace {
|
||||
|
||||
struct FusionOpRemoverPass
|
||||
: public mlir::PassWrapper<FusionOpRemoverPass, ::mlir::FunctionPass> {
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h.inc"
|
||||
|
||||
struct FusionOpRemoverPass : FusionOpRemoverPassBase<FusionOpRemoverPass> {
|
||||
void runOnFunction() override {
|
||||
getFunction().walk([&](mlir::lmhlo::FusionOp op) {
|
||||
mlir::OpBuilder builder(op);
|
||||
@ -52,8 +54,7 @@ struct FusionOpRemoverPass
|
||||
}
|
||||
};
|
||||
|
||||
struct StoreForwardingPass
|
||||
: mlir::PassWrapper<StoreForwardingPass, mlir::FunctionPass> {
|
||||
struct StoreForwardingPass : StoreForwardingPassBase<StoreForwardingPass> {
|
||||
mlir::StoreOp findStore(mlir::Operation* op,
|
||||
std::function<bool(mlir::StoreOp)> matches) {
|
||||
// Search from op upwards in the current block.
|
||||
@ -132,7 +133,7 @@ struct StoreForwardingPass
|
||||
};
|
||||
|
||||
struct DeadTempBufferRemovalPass
|
||||
: mlir::PassWrapper<DeadTempBufferRemovalPass, ::mlir::FunctionPass> {
|
||||
: DeadTempBufferRemovalPassBase<DeadTempBufferRemovalPass> {
|
||||
bool operationConsideredDead(mlir::Operation* op) {
|
||||
for (auto result : op->getResults()) {
|
||||
if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) {
|
||||
@ -183,8 +184,8 @@ struct DeadTempBufferRemovalPass
|
||||
};
|
||||
|
||||
struct MoveScalarComputationsIntoGpuLaunchPass
|
||||
: mlir::PassWrapper<MoveScalarComputationsIntoGpuLaunchPass,
|
||||
mlir::FunctionPass> {
|
||||
: MoveScalarComputationsIntoGpuLaunchPassBase<
|
||||
MoveScalarComputationsIntoGpuLaunchPass> {
|
||||
static bool isInliningBeneficiary(mlir::Operation* op) {
|
||||
return llvm::isa<mlir::ConstantOp, mlir::DimOp, mlir::SelectOp,
|
||||
mlir::CmpIOp>(op);
|
||||
@ -234,14 +235,13 @@ struct MoveScalarComputationsIntoGpuLaunchPass
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
mlir::FuncOp fun = getFunction();
|
||||
fun.walk(
|
||||
getFunction().walk(
|
||||
[](mlir::gpu::LaunchOp launch) { inlineOperationsIntoLaunch(launch); });
|
||||
}
|
||||
};
|
||||
|
||||
struct RewriteKernelSignaturePass
|
||||
: mlir::PassWrapper<RewriteKernelSignaturePass, mlir::FunctionPass> {
|
||||
: RewriteKernelSignaturePassBase<RewriteKernelSignaturePass> {
|
||||
void runOnFunction() override {
|
||||
mlir::FuncOp func = getFunction();
|
||||
mlir::ModuleOp module = func.getParentOfType<mlir::ModuleOp>();
|
||||
@ -349,15 +349,14 @@ struct RewriteKernelSignaturePass
|
||||
}
|
||||
};
|
||||
|
||||
struct MapParallelLoopsPass
|
||||
: public mlir::PassWrapper<MapParallelLoopsPass, mlir::FunctionPass> {
|
||||
struct MapParallelLoopsPass : MapParallelLoopsPassBase<MapParallelLoopsPass> {
|
||||
void runOnFunction() override {
|
||||
mlir::greedilyMapParallelSCFToGPU(getFunction().getBody());
|
||||
}
|
||||
};
|
||||
|
||||
struct FuseInnerParallelLoopsPass
|
||||
: public mlir::PassWrapper<FuseInnerParallelLoopsPass, mlir::FunctionPass> {
|
||||
: FuseInnerParallelLoopsPassBase<FuseInnerParallelLoopsPass> {
|
||||
void runOnFunction() override {
|
||||
getFunction().walk([](mlir::scf::ParallelOp op) {
|
||||
mlir::scf::naivelyFuseParallelOps(op.region());
|
||||
@ -366,12 +365,10 @@ struct FuseInnerParallelLoopsPass
|
||||
};
|
||||
|
||||
struct ParallelLoopCollapsingToFirstDimPass
|
||||
: public mlir::PassWrapper<ParallelLoopCollapsingToFirstDimPass,
|
||||
mlir::OperationPass<mlir::ModuleOp>> {
|
||||
void runOnOperation() override {
|
||||
mlir::Operation* module = getOperation();
|
||||
|
||||
module->walk([&](mlir::scf::ParallelOp op) {
|
||||
: ParallelLoopCollapsingToFirstDimPassBase<
|
||||
ParallelLoopCollapsingToFirstDimPass> {
|
||||
void runOnFunction() override {
|
||||
getFunction().walk([&](mlir::scf::ParallelOp op) {
|
||||
unsigned num_loops = op.getNumLoops();
|
||||
std::vector<unsigned> combinedLoops;
|
||||
combinedLoops.reserve(num_loops);
|
||||
@ -414,7 +411,7 @@ std::unique_ptr<mlir::FunctionPass> createMapParallelLoopsPass() {
|
||||
return absl::make_unique<MapParallelLoopsPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
|
||||
std::unique_ptr<mlir::FunctionPass>
|
||||
createParallelLoopCollapsingToFirstDimPass() {
|
||||
return absl::make_unique<ParallelLoopCollapsingToFirstDimPass>();
|
||||
}
|
||||
|
@ -57,9 +57,12 @@ std::unique_ptr<mlir::FunctionPass> createFuseInnerParallelLoopsPass();
|
||||
std::unique_ptr<mlir::FunctionPass> createMapParallelLoopsPass();
|
||||
|
||||
/// Collapses all loop dimension into the first one.
|
||||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
|
||||
std::unique_ptr<mlir::FunctionPass>
|
||||
createParallelLoopCollapsingToFirstDimPass();
|
||||
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h.inc"
|
||||
|
||||
} // namespace mlir_gpu
|
||||
} // namespace xla
|
||||
|
||||
|
100
tensorflow/compiler/xla/service/mlir_gpu/passes.td
Normal file
100
tensorflow/compiler/xla/service/mlir_gpu/passes.td
Normal file
@ -0,0 +1,100 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_TD_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_TD_
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def FusionOpRemoverPass : FunctionPass<"mlir-gpu-fusion-op-remover"> {
|
||||
let summary = "Removes lhlo fusion ops by inlining their regions.";
|
||||
let constructor = "createFusionOpRemoverPass()";
|
||||
let description = [{
|
||||
Replaces a FusionOp by the operations contained in its region.
|
||||
}];
|
||||
}
|
||||
|
||||
def StoreForwardingPass : FunctionPass<"mlir-gpu-store-forwarding"> {
|
||||
let summary = "Limited pass to forward stores to loads.";
|
||||
let constructor = "createStoreForwardingPass()";
|
||||
let description = [{
|
||||
Replaces a load that immediately follows a store to the same address with
|
||||
the stored value.
|
||||
}];
|
||||
}
|
||||
|
||||
def DeadTempBufferRemovalPass
|
||||
: FunctionPass<"mlir-gpu-dead-temp-buffer-removal"> {
|
||||
let summary = "Removal of dead temp buffers.";
|
||||
let constructor = "createDeadTempBufferRemovalPass()";
|
||||
let description = [{
|
||||
Removes temporary buffers that are only written to but never read from or
|
||||
that are read but the read value is not used. Needs an analysis that proves
|
||||
that loads and stores are side-effect free (in bounds, no aliasing, etc.).
|
||||
}];
|
||||
}
|
||||
|
||||
def MoveScalarComputationsIntoGpuLaunchPass
|
||||
: FunctionPass<"mlir-gpu-inline-scalar-computation"> {
|
||||
let summary = "Pass to Move scalar computations to the GPULaunchOp body.";
|
||||
let constructor = "createMoveScalarComputationsIntoGpuLaunchPass()";
|
||||
let description = [{
|
||||
Moves scalar computations to the GPULaunchOp body.
|
||||
}];
|
||||
}
|
||||
|
||||
def RewriteKernelSignaturePass
|
||||
: FunctionPass<"mlir-gpu-rewrite-signatures"> {
|
||||
let summary = "Rewrite kernel signatures to be deterministic.";
|
||||
let constructor = "createRewriteKernelSignaturePass()";
|
||||
let description = [{
|
||||
Sorts the operands to the kernel for a deterministic order. First operands
|
||||
that are defined by function arguments, followed by operands that are
|
||||
returned from the function. This only works for simple functions without
|
||||
control flow and can be used in cases where the kernel is extracted and used
|
||||
independently of the host-side code.
|
||||
}];
|
||||
}
|
||||
|
||||
def MapParallelLoopsPass
|
||||
: FunctionPass<"mlir-gpu-map-parallel-loops"> {
|
||||
let summary = "Greedily maps loops to GPU hardware dimensions.";
|
||||
let constructor = "createMapParallelLoopsPass()";
|
||||
let description = [{
|
||||
Greedily maps loops to GPU hardware dimensions.
|
||||
}];
|
||||
}
|
||||
|
||||
def FuseInnerParallelLoopsPass
|
||||
: FunctionPass<"mlir-gpu-fuse-inner-parallel-loops"> {
|
||||
let summary = "Limited pass to forward stores to loads.";
|
||||
let constructor = "createFuseInnerParallelLoopsPass()";
|
||||
let description = [{
|
||||
Directs parallel loop fusion to the inner loops. This cannot be done with
|
||||
a passmanager alone ATM, as nested pass managers require operations to
|
||||
be closed from above.
|
||||
}];
|
||||
}
|
||||
|
||||
def ParallelLoopCollapsingToFirstDimPass
|
||||
: FunctionPass<"mlir-gpu-collapse-parallel-loops"> {
|
||||
let summary = "Collaps n-dimensional loops into one-dimensional ones.";
|
||||
let constructor = "createParallelLoopCollapsingToFirstDimPass()";
|
||||
let description = [{
|
||||
Collapses all loop dimension of a parallel loop into the first one.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_TD_
|
@ -18,13 +18,16 @@ package_group(
|
||||
)
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
data = [
|
||||
":test_utilities",
|
||||
"@llvm-project//mlir:run_lit.sh",
|
||||
],
|
||||
default_tags = tf_cuda_tests_tags() + [
|
||||
"no_pip",
|
||||
"config-cuda-only",
|
||||
"no_rocm",
|
||||
],
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
driver = "//tensorflow/compiler/mlir:run_lit.sh",
|
||||
exclude = [
|
||||
# TODO(b/137624192): Reenable once we can fuse reductions.
|
||||
"fused_reduce.hlo",
|
||||
|
24
tensorflow/compiler/xla/service/mlir_gpu/tests/passes/BUILD
Normal file
24
tensorflow/compiler/xla/service/mlir_gpu/tests/passes/BUILD
Normal file
@ -0,0 +1,24 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
glob_lit_tests(
|
||||
data = [
|
||||
":test_utilities",
|
||||
"@llvm-project//mlir:run_lit.sh",
|
||||
],
|
||||
driver = "//tensorflow/compiler/mlir:run_lit.sh",
|
||||
test_file_exts = ["mlir"],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/xla/service/mlir_gpu:xla-mlir-gpu-opt",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
],
|
||||
)
|
@ -0,0 +1,60 @@
|
||||
// RUN: xla-mlir-gpu-opt --mlir-gpu-dead-temp-buffer-removal %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @dead
|
||||
func @dead() {
|
||||
// CHECK-NOT: alloc
|
||||
%0 = alloc() : memref<42xi32>
|
||||
%c0 = constant 0 : i32
|
||||
%c12 = constant 12 : index
|
||||
store %c0, %0[%c12] : memref<42xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @dead_load
|
||||
func @dead_load() {
|
||||
// CHECK-NOT: alloc
|
||||
%0 = alloc() : memref<42xi32>
|
||||
%c0 = constant 0 : i32
|
||||
%c12 = constant 12 : index
|
||||
store %c0, %0[%c12] : memref<42xi32>
|
||||
%1 = load %0[%c12] : memref<42xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @used_load
|
||||
func @used_load() -> i32 {
|
||||
// CHECK: alloc
|
||||
%0 = alloc() : memref<42xi32>
|
||||
%c0 = constant 0 : i32
|
||||
%c12 = constant 12 : index
|
||||
store %c0, %0[%c12] : memref<42xi32>
|
||||
%1 = load %0[%c12] : memref<42xi32>
|
||||
return %1 : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @dead_subview
|
||||
func @dead_subview() {
|
||||
// CHECK-NOT: alloc
|
||||
%0 = alloc() : memref<42xi32>
|
||||
%c0 = constant 0 : i32
|
||||
%c1 = constant 1 : index
|
||||
%c4 = constant 4 : index
|
||||
%c12 = constant 12 : index
|
||||
store %c0, %0[%c12] : memref<42xi32>
|
||||
%1 = subview %0[%c12][%c4][%c1] : memref<42xi32> to memref<?xi32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @used_subview
|
||||
func @used_subview() -> i32 {
|
||||
// CHECK: alloc
|
||||
%0 = alloc() : memref<42xi32>
|
||||
%c0 = constant 0 : i32
|
||||
%c1 = constant 1 : index
|
||||
%c4 = constant 4 : index
|
||||
%c12 = constant 12 : index
|
||||
store %c0, %0[%c12] : memref<42xi32>
|
||||
%1 = subview %0[%c12][%c4][%c1] : memref<42xi32> to memref<?xi32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
|
||||
%2 = load %1[%c1] : memref<?xi32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
|
||||
return %2 : i32
|
||||
}
|
35
tensorflow/compiler/xla/service/mlir_gpu/xla_mlir_gpu_opt.cc
Normal file
35
tensorflow/compiler/xla/service/mlir_gpu/xla_mlir_gpu_opt.cc
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/InitAllDialects.h" // from @llvm-project
|
||||
#include "mlir/InitAllPasses.h" // from @llvm-project
|
||||
#include "mlir/Support/MlirOptMain.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
mlir::registerAllPasses();
|
||||
mlir::mhlo::registerAllMhloPasses();
|
||||
mlir::lmhlo::registerAllLmhloPasses();
|
||||
xla::mlir_gpu::registerXlaMlirGpuPasses();
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
mlir::registerAllDialects(registry);
|
||||
mlir::mhlo::registerAllMhloDialects(registry);
|
||||
|
||||
return failed(mlir::MlirOptMain(
|
||||
argc, argv, "XLA mlir gpu backend pass driver\n", registry));
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user