diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 82f2885550f..006ae12b3b2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -57,7 +57,6 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:target_constants", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/mlir_gpu:kernel_lowering", - "//tensorflow/compiler/xla/service/mlir_gpu:passes", "//tensorflow/core:lib", "//tensorflow/core/platform:cuda_libdevice_path", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 51c0f353da0..92f72e4ec4d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -57,7 +57,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h" -#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/path.h" @@ -145,7 +144,7 @@ Status LowerTFtoGPU(mlir::ModuleOp module, llvm::ArrayRef tile_sizes, pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); // Fuse the inner-most loops. pm.addNestedPass( - xla::mlir_gpu::createFuseInnerParallelLoopsPass()); + mlir::kernel_gen::transforms::CreateFuseInnerParallelLoopsPass()); // Run CSE to ensure that loads and stores to the same subview get // recognized as such. pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index d328c5b28cc..16e956c8a01 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -76,6 +76,7 @@ cc_library( "bufferize_pass.cc", "embed_memref_prints.cc", "embed_tf_framework_pass.cc", + "fuse_inner_parallel_loops_pass.cc", "gpu_kernel_to_blob_pass.cc", "map_parallel_loops_to_gpu.cc", "parallel_loops_to_sequential.cc", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/fuse_inner_parallel_loops_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/fuse_inner_parallel_loops_pass.cc new file mode 100644 index 00000000000..d9bb794fa97 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/fuse_inner_parallel_loops_pass.cc @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +struct FuseInnerParallelLoopsPass + : FuseInnerParallelLoopsPassBase { + void runOnFunction() override { + getFunction().walk([](mlir::scf::ParallelOp op) { + mlir::scf::naivelyFuseParallelOps(op.region()); + }); + } +}; + +} // namespace + +std::unique_ptr CreateFuseInnerParallelLoopsPass() { + return std::make_unique(); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index 4a940dde997..203c9036936 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -83,6 +83,11 @@ std::unique_ptr CreateEmbedMemRefPrintsPass(); /// Greedily maps loops to GPU hardware dimensions. std::unique_ptr CreateMapParallelLoopsPass(); +/// We need to direct fusion to the inner loops. This cannot be done with +/// a passmanager alone ATM, as nested pass managers require operations to +/// be closed from above. +std::unique_ptr CreateFuseInnerParallelLoopsPass(); + } // namespace transforms #define GEN_PASS_REGISTRATION diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index af7613ecd3a..4304e3d2d0d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -107,4 +107,15 @@ def MapParallelLoopsPass }]; } +def FuseInnerParallelLoopsPass + : FunctionPass<"fuse-inner-parallel-loops"> { + let summary = "Limited pass to forward stores to loads."; + let constructor = "transforms::CreateFuseInnerParallelLoopsPass()"; + let description = [{ + Directs parallel loop fusion to the inner loops. This cannot be done with + a passmanager alone ATM, as nested pass managers require operations to + be closed from above. + }]; +} + #endif // TF_KERNEL_GEN_PASSES