Add lowering of affine operations to the LHLO->GPU->LLVM pipeline.

These operations are created when tiling dynamically shaped HLO operations via linalg. This does not happen in the flow that uses the XLA frontend, so we missed this.

PiperOrigin-RevId: 294189195
Change-Id: I7f847ce682c048a312a918fed2c8c6207f196d54
This commit is contained in:
Stephan Herhut 2020-02-10 03:50:54 -08:00 committed by TensorFlower Gardener
parent 1e8d9114a7
commit 04f2870814
2 changed files with 4 additions and 2 deletions

View File

@ -152,6 +152,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@llvm-project//mlir:AffineDialectRegistration",
"@llvm-project//mlir:AffineToStandardTransforms",
"@llvm-project//mlir:CFGTransforms",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:GPUDialectRegistration",

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "absl/memory/memory.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // TF:llvm-project
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // TF:llvm-project
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // TF:llvm-project
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" // TF:llvm-project
@ -330,7 +331,7 @@ class LowerToNVVMPass
::mlir::populateLinalgToLLVMConversionPatterns(converter, patterns,
&getContext());
::mlir::populateGpuToNVVMConversionPatterns(converter, patterns);
::mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());
::mlir::ConversionTarget target(getContext());
target.addIllegalDialect<::mlir::gpu::GPUDialect>();
target.addIllegalOp<::mlir::LLVM::ExpOp>();
@ -339,7 +340,7 @@ class LowerToNVVMPass
// TODO(csigg): Remove once we support replacing non-root ops.
target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp,
::mlir::gpu::YieldOp>();
if (failed(applyPartialConversion(m, target, patterns, &converter))) {
if (failed(mlir::applyFullConversion(m, target, patterns, &converter))) {
signalPassFailure();
}
}