Move MapParallelLoopsPass to kernel _generator directory.
PiperOrigin-RevId: 348611655 Change-Id: Id2341c80f639018e5b4c92ddf82e722763d22917
This commit is contained in:
parent
75e4d1a904
commit
f72495d54c
@ -157,7 +157,8 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef<uint32_t> tile_sizes,
|
|||||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass());
|
||||||
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass());
|
||||||
// Greedily map the remaining loop to GPU hardware dimensions.
|
// Greedily map the remaining loop to GPU hardware dimensions.
|
||||||
pm.addNestedPass<::mlir::FuncOp>(xla::mlir_gpu::createMapParallelLoopsPass());
|
pm.addNestedPass<::mlir::FuncOp>(
|
||||||
|
mlir::kernel_gen::transforms::CreateMapParallelLoopsPass());
|
||||||
|
|
||||||
// Now lower the shape computations, bufferize all remaining ops and insert
|
// Now lower the shape computations, bufferize all remaining ops and insert
|
||||||
// deallocs.
|
// deallocs.
|
||||||
|
@ -77,6 +77,7 @@ cc_library(
|
|||||||
"embed_memref_prints.cc",
|
"embed_memref_prints.cc",
|
||||||
"embed_tf_framework_pass.cc",
|
"embed_tf_framework_pass.cc",
|
||||||
"gpu_kernel_to_blob_pass.cc",
|
"gpu_kernel_to_blob_pass.cc",
|
||||||
|
"map_parallel_loops_to_gpu.cc",
|
||||||
"parallel_loops_to_sequential.cc",
|
"parallel_loops_to_sequential.cc",
|
||||||
"same_shape_propagation.cc",
|
"same_shape_propagation.cc",
|
||||||
"shape_to_descriptors_pass.cc",
|
"shape_to_descriptors_pass.cc",
|
||||||
@ -103,6 +104,7 @@ cc_library(
|
|||||||
":tf_framework_legalize_to_llvm",
|
":tf_framework_legalize_to_llvm",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:GPUDialect",
|
"@llvm-project//mlir:GPUDialect",
|
||||||
|
"@llvm-project//mlir:GPUTransforms",
|
||||||
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
|
"@llvm-project//mlir:GPUToGPURuntimeTransforms",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:LLVMDialect",
|
"@llvm-project//mlir:LLVMDialect",
|
||||||
|
@ -0,0 +1,41 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace kernel_gen {
|
||||||
|
namespace transforms {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#define GEN_PASS_CLASSES
|
||||||
|
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
|
||||||
|
|
||||||
|
struct MapParallelLoopsPass : MapParallelLoopsPassBase<MapParallelLoopsPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
mlir::greedilyMapParallelSCFToGPU(getFunction().getBody());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::FunctionPass> CreateMapParallelLoopsPass() {
|
||||||
|
return std::make_unique<MapParallelLoopsPass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace transforms
|
||||||
|
} // namespace kernel_gen
|
||||||
|
} // namespace mlir
|
@ -52,7 +52,7 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateTFKernelToLLVMPass(
|
|||||||
|
|
||||||
// Pass to tranform shape computations in shape dialect to standard and scf
|
// Pass to tranform shape computations in shape dialect to standard and scf
|
||||||
// using memref descriptors.
|
// using memref descriptors.
|
||||||
std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass();
|
std::unique_ptr<OperationPass<ModuleOp>> CreateShapeToDescriptorsPass();
|
||||||
|
|
||||||
// Pass to tranform hlo-level computations on values to their corresponding
|
// Pass to tranform hlo-level computations on values to their corresponding
|
||||||
// parts on buffers.
|
// parts on buffers.
|
||||||
@ -80,6 +80,9 @@ std::unique_ptr<FunctionPass> CreatePropagateShapeKnowledgeToKernels();
|
|||||||
// Pass to print content of memrefs.
|
// Pass to print content of memrefs.
|
||||||
std::unique_ptr<FunctionPass> CreateEmbedMemRefPrintsPass();
|
std::unique_ptr<FunctionPass> CreateEmbedMemRefPrintsPass();
|
||||||
|
|
||||||
|
/// Greedily maps loops to GPU hardware dimensions.
|
||||||
|
std::unique_ptr<mlir::FunctionPass> CreateMapParallelLoopsPass();
|
||||||
|
|
||||||
} // namespace transforms
|
} // namespace transforms
|
||||||
|
|
||||||
#define GEN_PASS_REGISTRATION
|
#define GEN_PASS_REGISTRATION
|
||||||
|
@ -98,4 +98,13 @@ def EmbedMemRefPrintsPass : FunctionPass<"embed-memref-prints"> {
|
|||||||
let constructor = "transforms::CreateEmbedMemRefPrintsPass()";
|
let constructor = "transforms::CreateEmbedMemRefPrintsPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def MapParallelLoopsPass
|
||||||
|
: FunctionPass<"map-parallel-loops-to-gpu"> {
|
||||||
|
let summary = "Greedily maps loops to GPU hardware dimensions.";
|
||||||
|
let constructor = "transforms::CreateMapParallelLoopsPass()";
|
||||||
|
let description = [{
|
||||||
|
Greedily maps loops to GPU hardware dimensions.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // TF_KERNEL_GEN_PASSES
|
#endif // TF_KERNEL_GEN_PASSES
|
||||||
|
Loading…
Reference in New Issue
Block a user