Move MapParallelLoopsPass to kernel _generator directory.

PiperOrigin-RevId: 348611655
Change-Id: Id2341c80f639018e5b4c92ddf82e722763d22917
This commit is contained in:
Stephan Herhut 2020-12-22 04:06:50 -08:00 committed by TensorFlower Gardener
parent 75e4d1a904
commit f72495d54c
5 changed files with 58 additions and 2 deletions

View File

@ -157,7 +157,8 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
// Greedily map the remaining loop to GPU hardware dimensions.
pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass());
pm.addNestedPass<::mlir::FuncOp>(
mlir::kernel_gen::transforms::CreateMapParallelLoopsPass());
// Now lower the shape computations, bufferize all remaining ops and insert
// deallocs.

View File

@ -77,6 +77,7 @@ cc_library(
"embed_memref_prints.cc",
"embed_tf_framework_pass.cc",
"gpu_kernel_to_blob_pass.cc",
"map_parallel_loops_to_gpu.cc",
"parallel_loops_to_sequential.cc",
"same_shape_propagation.cc",
"shape_to_descriptors_pass.cc",
@ -103,6 +104,7 @@ cc_library(
":tf_framework_legalize_to_llvm",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUTransforms",
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",

View File

@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
namespace mlir {
namespace kernel_gen {
namespace transforms {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
struct MapParallelLoopsPass : MapParallelLoopsPassBase<MapParallelLoopsPass> {
void runOnFunction() override {
mlir::greedilyMapParallelSCFToGPU(getFunction().getBody());
}
};
} // namespace
std::unique_ptr<mlir::FunctionPass> CreateMapParallelLoopsPass() {
return std::make_unique<MapParallelLoopsPass>();
}
} // namespace transforms
} // namespace kernel_gen
} // namespace mlir

View File

@ -52,7 +52,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateTFKernelToLLVMPass(
// Pass to tranform shape computations in shape dialect to standard and scf
// using memref descriptors.
std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateShapeToDescriptorsPass();
// Pass to tranform hlo-level computations on values to their corresponding
// parts on buffers.
@ -80,6 +80,9 @@ std::unique_ptr<FunctionPass> CreatePropagateShapeKnowledgeToKernels();
// Pass to print content of memrefs.
std::unique_ptr<FunctionPass> CreateEmbedMemRefPrintsPass();
/// Greedily maps loops to GPU hardware dimensions.
std::unique_ptr<mlir::FunctionPass> CreateMapParallelLoopsPass();
} // namespace transforms
#define GEN_PASS_REGISTRATION

View File

@ -98,4 +98,13 @@ def EmbedMemRefPrintsPass : FunctionPass<"embed-memref-prints"> {
let constructor = "transforms::CreateEmbedMemRefPrintsPass()";
}
def MapParallelLoopsPass
: FunctionPass<"map-parallel-loops-to-gpu"> {
let summary = "Greedily maps loops to GPU hardware dimensions.";
let constructor = "transforms::CreateMapParallelLoopsPass()";
let description = [{
Greedily maps loops to GPU hardware dimensions.
}];
}
#endif // TF_KERNEL_GEN_PASSES