From c78647ccd1d2ada2eb20a8fec7e6b38412f3ca46 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 21 Jan 2021 01:21:23 -0800 Subject: [PATCH] [MLIR] Migrate TF from STD complex ops to ComplexDialect. PiperOrigin-RevId: 352966408 Change-Id: I1f422862f0cc1bf33fb60131dba06cf47e0c97ac --- tensorflow/compiler/mlir/hlo/BUILD | 1 + .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 15 +++++++++------ .../mhlo/transforms/legalize_to_linalg.cc | 12 +++++++----- .../mlir/hlo/tests/hlo-legalize-to-linalg.mlir | 4 ++-- .../mlir/hlo/tests/lhlo-legalize-to-linalg.mlir | 6 +++--- tensorflow/compiler/mlir/tensorflow/BUILD | 1 + .../mlir/tensorflow/dialect_registration.h | 2 ++ .../mlir/tools/kernel_gen/transforms/BUILD | 2 ++ .../kernel_gen/transforms/bufferize_pass.cc | 17 +++++++---------- .../transforms/kernel_lowering_passes.cc | 3 +++ .../transforms/tf_kernel_to_llvm_pass.cc | 8 ++++++-- 11 files changed, 43 insertions(+), 28 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index f83b9860609..20967e63bb5 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -583,6 +583,7 @@ cc_library( ":lhlo", ":map_hlo_to_lhlo_op", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 7ce33fb1ac9..77600e040e9 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -23,6 +23,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinTypes.h" @@ -41,7 +42,7 @@ template <> struct LhloToScalarOp { using FOp = ::mlir::AddFOp; using IOp = ::mlir::AddIOp; - using COp = ::mlir::AddCFOp; + using COp = ::mlir::complex::AddOp; }; template <> struct LhloToScalarOp { @@ -67,7 +68,7 @@ template <> struct LhloToScalarOp { using FOp = ::mlir::SubFOp; using IOp = ::mlir::SubIOp; - using COp = ::mlir::SubCFOp; + using COp = ::mlir::complex::SubOp; }; // Alias for the map from LHLO binary op type to STD floating-point op type. @@ -261,8 +262,8 @@ template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, - b); + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, + args, b); } template <> @@ -270,7 +271,8 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, + b); } template <> @@ -278,7 +280,8 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, + b); } template <> diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 2cee0b4981f..7f630a5777c 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1298,8 +1298,8 @@ struct LhloLegalizeToLinalgPass void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); auto func = getFunction(); populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); @@ -1312,14 +1312,16 @@ struct LhloLegalizeToLinalgPass struct HloLegalizeToLinalgPass : public PassWrapper { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry.insert(); } void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); auto func = getFunction(); mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir index 4b2f354bdf0..ade42d8e950 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir @@ -33,7 +33,7 @@ func @integer_add(%lhs: tensor<2x2xi32>, func @complex_add(%lhs: tensor<2x2xcomplex>, %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { // CHECK: linalg.generic - // CHECK: addcf + // CHECK: complex.add %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex>, tensor<2x2xcomplex>) -> tensor<2x2xcomplex> return %0 : tensor<2x2xcomplex> @@ -128,7 +128,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>, func @complex_sub(%lhs: tensor<2x2xcomplex>, %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { // CHECK: linalg.generic - // CHECK: subcf + // CHECK: complex.sub %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex>, tensor<2x2xcomplex>) -> tensor<2x2xcomplex> return %0 : tensor<2x2xcomplex> diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir index 3c757959137..dfcc3c726dd 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir @@ -700,7 +700,7 @@ func @complex(%real: memref<2x2xf32>, } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex): -// CHECK-NEXT: %[[RESULT:.*]] = create_complex %[[RE]], %[[IM]] : complex +// CHECK-NEXT: %[[RESULT:.*]] = complex.create %[[RE]], %[[IM]] : complex // CHECK-NEXT: linalg.yield %[[RESULT]] : complex // ----- @@ -714,7 +714,7 @@ func @real(%cplx: memref<2x2xcomplex>, } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex, %[[REAL_OUT:.*]]: f32): -// CHECK-NEXT: %[[REAL:.*]] = re %[[CPLX_IN:.*]] : complex +// CHECK-NEXT: %[[REAL:.*]] = complex.re %[[CPLX_IN:.*]] : complex // CHECK-NEXT: linalg.yield %[[REAL]] : f32 // ----- @@ -728,7 +728,7 @@ func @imag(%cplx: memref<2x2xcomplex>, } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex, %[[IMAG_OUT:.*]]: f32): -// CHECK-NEXT: %[[IMAG:.*]] = im %[[CPLX_IN:.*]] : complex +// CHECK-NEXT: %[[IMAG:.*]] = complex.im %[[CPLX_IN:.*]] : complex // CHECK-NEXT: linalg.yield %[[IMAG]] : f32 // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 47a6c59b295..f1dad02836c 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -649,6 +649,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:CallOpInterfacesIncGen", + "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", diff --git a/tensorflow/compiler/mlir/tensorflow/dialect_registration.h b/tensorflow/compiler/mlir/tensorflow/dialect_registration.h index a63bfd154ab..f81f7afa530 100644 --- a/tensorflow/compiler/mlir/tensorflow/dialect_registration.h +++ b/tensorflow/compiler/mlir/tensorflow/dialect_registration.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_ +#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -28,6 +29,7 @@ namespace mlir { // intended for tools that need to register dialects before parsing .mlir files. inline void RegisterAllTensorFlowDialects(DialectRegistry ®istry) { registry.insert(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 9c09755b9ee..21b43241b29 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -95,6 +95,8 @@ cc_library( ":tf_framework_legalize_to_llvm", "@llvm-project//llvm:Support", "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:ComplexToLLVM", "@llvm-project//mlir:Affine", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:Analysis", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc index f84f59a9bb6..45ba5e0c5b1 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project #include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project #include "mlir/Dialect/SCF/SCF.h" // from @llvm-project @@ -111,10 +112,8 @@ struct HloBufferizePass : public HloBufferizePassBase { OwningRewritePatternList patterns; auto& context = getContext(); ConversionTarget target(context); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalDialect(); CustomBufferizeTypeConverter converter; @@ -162,11 +161,10 @@ struct FinalBufferizePass : public FinalBufferizePassBase { void runOnOperation() override { auto& context = getContext(); ConversionTarget target(context); - target.addLegalDialect(); + target.addLegalDialect< + complex::ComplexDialect, scf::SCFDialect, StandardOpsDialect, + tensor::TensorDialect, tf_framework::TFFrameworkDialect, AffineDialect, + shape::ShapeDialect, lmhlo::LmhloDialect, linalg::LinalgDialect>(); target.addLegalOp(); target.addIllegalDialect(); @@ -212,4 +210,3 @@ std::unique_ptr > CreateFinalBufferizePass() { } // namespace transforms } // namespace kernel_gen } // namespace mlir - diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_lowering_passes.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_lowering_passes.cc index 1c82e97ce53..de33556c336 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_lowering_passes.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_lowering_passes.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project @@ -51,6 +52,7 @@ class GpuKernelToNVVMPass LLVMTypeConverter converter(m.getContext()); populateStdToLLVMConversionPatterns(converter, patterns); populateGpuToNVVMConversionPatterns(converter, patterns); + populateComplexToLLVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); configureGpuToNVVMConversionLegality(target); if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) { @@ -75,6 +77,7 @@ class GpuKernelToROCDLPass LLVMTypeConverter converter(m.getContext()); populateStdToLLVMConversionPatterns(converter, patterns); populateGpuToROCDLConversionPatterns(converter, patterns); + populateComplexToLLVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); configureGpuToROCDLConversionLegality(target); if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc index 454f01f0743..4a764af6355 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc @@ -16,9 +16,11 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project @@ -253,6 +255,7 @@ class TFKernelToLLVMPass : public TFKernelToLLVMPassBase { populateStdExpandOpsPatterns(ctx, patterns); populateStdToLLVMConversionPatterns(type_converter, patterns); + populateComplexToLLVMConversionPatterns(type_converter, patterns); tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter, &patterns); patterns.insert( @@ -260,8 +263,9 @@ class TFKernelToLLVMPass : public TFKernelToLLVMPassBase { // Set target. ConversionTarget target(*ctx); target.addLegalDialect(); - target.addIllegalDialect(); + target + .addIllegalDialect(); target.addIllegalOp(); // Mark modules as legal. target.addLegalOp();