[MLIR] Migrate TF from STD complex ops to ComplexDialect.
PiperOrigin-RevId: 352966408 Change-Id: I1f422862f0cc1bf33fb60131dba06cf47e0c97ac
This commit is contained in:
parent
4da05d7712
commit
c78647ccd1
@ -583,6 +583,7 @@ cc_library(
|
||||
":lhlo",
|
||||
":map_hlo_to_lhlo_op",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:ComplexDialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
@ -41,7 +42,7 @@ template <>
|
||||
struct LhloToScalarOp<lmhlo::AddOp> {
|
||||
using FOp = ::mlir::AddFOp;
|
||||
using IOp = ::mlir::AddIOp;
|
||||
using COp = ::mlir::AddCFOp;
|
||||
using COp = ::mlir::complex::AddOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<lmhlo::CompareOp> {
|
||||
@ -67,7 +68,7 @@ template <>
|
||||
struct LhloToScalarOp<lmhlo::SubOp> {
|
||||
using FOp = ::mlir::SubFOp;
|
||||
using IOp = ::mlir::SubIOp;
|
||||
using COp = ::mlir::SubCFOp;
|
||||
using COp = ::mlir::complex::SubOp;
|
||||
};
|
||||
|
||||
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
||||
@ -261,8 +262,8 @@ template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
|
||||
b);
|
||||
return MapLhloOpToStdScalarOpImpl<complex::CreateOp>{}(loc, result_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -270,7 +271,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
|
||||
return MapLhloOpToStdScalarOpImpl<complex::ReOp>{}(loc, result_types, args,
|
||||
b);
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -278,7 +280,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
|
||||
return MapLhloOpToStdScalarOpImpl<complex::ImOp>{}(loc, result_types, args,
|
||||
b);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -1298,8 +1298,8 @@ struct LhloLegalizeToLinalgPass
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
AffineDialect>();
|
||||
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
||||
StandardOpsDialect, AffineDialect>();
|
||||
|
||||
auto func = getFunction();
|
||||
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
@ -1312,14 +1312,16 @@ struct LhloLegalizeToLinalgPass
|
||||
struct HloLegalizeToLinalgPass
|
||||
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<linalg::LinalgDialect, scf::SCFDialect>();
|
||||
registry.insert<linalg::LinalgDialect, scf::SCFDialect,
|
||||
complex::ComplexDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
tensor::TensorDialect, scf::SCFDialect>();
|
||||
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
||||
StandardOpsDialect, tensor::TensorDialect,
|
||||
scf::SCFDialect>();
|
||||
|
||||
auto func = getFunction();
|
||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
|
@ -33,7 +33,7 @@ func @integer_add(%lhs: tensor<2x2xi32>,
|
||||
func @complex_add(%lhs: tensor<2x2xcomplex<f32>>,
|
||||
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: addcf
|
||||
// CHECK: complex.add
|
||||
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
|
||||
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
|
||||
return %0 : tensor<2x2xcomplex<f32>>
|
||||
@ -128,7 +128,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
|
||||
func @complex_sub(%lhs: tensor<2x2xcomplex<f32>>,
|
||||
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: subcf
|
||||
// CHECK: complex.sub
|
||||
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
|
||||
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
|
||||
return %0 : tensor<2x2xcomplex<f32>>
|
||||
|
@ -700,7 +700,7 @@ func @complex(%real: memref<2x2xf32>,
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex<f32>):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = create_complex %[[RE]], %[[IM]] : complex<f32>
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = complex.create %[[RE]], %[[IM]] : complex<f32>
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
@ -714,7 +714,7 @@ func @real(%cplx: memref<2x2xcomplex<f32>>,
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[REAL_OUT:.*]]: f32):
|
||||
// CHECK-NEXT: %[[REAL:.*]] = re %[[CPLX_IN:.*]] : complex<f32>
|
||||
// CHECK-NEXT: %[[REAL:.*]] = complex.re %[[CPLX_IN:.*]] : complex<f32>
|
||||
// CHECK-NEXT: linalg.yield %[[REAL]] : f32
|
||||
|
||||
// -----
|
||||
@ -728,7 +728,7 @@ func @imag(%cplx: memref<2x2xcomplex<f32>>,
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[IMAG_OUT:.*]]: f32):
|
||||
// CHECK-NEXT: %[[IMAG:.*]] = im %[[CPLX_IN:.*]] : complex<f32>
|
||||
// CHECK-NEXT: %[[IMAG:.*]] = complex.im %[[CPLX_IN:.*]] : complex<f32>
|
||||
// CHECK-NEXT: linalg.yield %[[IMAG]] : f32
|
||||
|
||||
// -----
|
||||
|
@ -649,6 +649,7 @@ cc_library(
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:CallOpInterfacesIncGen",
|
||||
"@llvm-project//mlir:ComplexDialect",
|
||||
"@llvm-project//mlir:ControlFlowInterfaces",
|
||||
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
||||
"@llvm-project//mlir:Dialect",
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_
|
||||
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
@ -28,6 +29,7 @@ namespace mlir {
|
||||
// intended for tools that need to register dialects before parsing .mlir files.
|
||||
inline void RegisterAllTensorFlowDialects(DialectRegistry ®istry) {
|
||||
registry.insert<mlir::StandardOpsDialect, mlir::TF::TensorFlowDialect,
|
||||
mlir::complex::ComplexDialect,
|
||||
mlir::tf_device::TensorFlowDeviceDialect,
|
||||
mlir::tf_executor::TensorFlowExecutorDialect,
|
||||
mlir::tf_saved_model::TensorFlowSavedModelDialect>();
|
||||
|
@ -95,6 +95,8 @@ cc_library(
|
||||
":tf_framework_legalize_to_llvm",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:TransformUtils",
|
||||
"@llvm-project//mlir:ComplexDialect",
|
||||
"@llvm-project//mlir:ComplexToLLVM",
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:AllPassesAndDialects",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
|
||||
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
|
||||
@ -111,10 +112,8 @@ struct HloBufferizePass : public HloBufferizePassBase<HloBufferizePass> {
|
||||
OwningRewritePatternList patterns;
|
||||
auto& context = getContext();
|
||||
ConversionTarget target(context);
|
||||
target.addLegalDialect<lmhlo::LmhloDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalDialect<lmhlo::LmhloDialect, StandardOpsDialect>();
|
||||
target.addLegalDialect<tensor::TensorDialect>();
|
||||
target.addLegalDialect<complex::ComplexDialect, lmhlo::LmhloDialect,
|
||||
StandardOpsDialect, tensor::TensorDialect>();
|
||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||
|
||||
CustomBufferizeTypeConverter converter;
|
||||
@ -162,11 +161,10 @@ struct FinalBufferizePass : public FinalBufferizePassBase<FinalBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
auto& context = getContext();
|
||||
ConversionTarget target(context);
|
||||
target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
|
||||
tensor::TensorDialect,
|
||||
tf_framework::TFFrameworkDialect, AffineDialect,
|
||||
shape::ShapeDialect, lmhlo::LmhloDialect,
|
||||
linalg::LinalgDialect>();
|
||||
target.addLegalDialect<
|
||||
complex::ComplexDialect, scf::SCFDialect, StandardOpsDialect,
|
||||
tensor::TensorDialect, tf_framework::TFFrameworkDialect, AffineDialect,
|
||||
shape::ShapeDialect, lmhlo::LmhloDialect, linalg::LinalgDialect>();
|
||||
target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
|
||||
|
||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||
@ -212,4 +210,3 @@ std::unique_ptr<OperationPass<ModuleOp> > CreateFinalBufferizePass() {
|
||||
} // namespace transforms
|
||||
} // namespace kernel_gen
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
|
||||
@ -51,6 +52,7 @@ class GpuKernelToNVVMPass
|
||||
LLVMTypeConverter converter(m.getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateGpuToNVVMConversionPatterns(converter, patterns);
|
||||
populateComplexToLLVMConversionPatterns(converter, patterns);
|
||||
ConversionTarget target(getContext());
|
||||
configureGpuToNVVMConversionLegality(target);
|
||||
if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
|
||||
@ -75,6 +77,7 @@ class GpuKernelToROCDLPass
|
||||
LLVMTypeConverter converter(m.getContext());
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateGpuToROCDLConversionPatterns(converter, patterns);
|
||||
populateComplexToLLVMConversionPatterns(converter, patterns);
|
||||
ConversionTarget target(getContext());
|
||||
configureGpuToROCDLConversionLegality(target);
|
||||
if (failed(mlir::applyFullConversion(m, target, std::move(patterns)))) {
|
||||
|
@ -16,9 +16,11 @@ limitations under the License.
|
||||
#include <stdexcept>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" // from @llvm-project
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Complex/IR/Complex.h" // from @llvm-project
|
||||
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project
|
||||
@ -253,6 +255,7 @@ class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
|
||||
|
||||
populateStdExpandOpsPatterns(ctx, patterns);
|
||||
populateStdToLLVMConversionPatterns(type_converter, patterns);
|
||||
populateComplexToLLVMConversionPatterns(type_converter, patterns);
|
||||
tf_framework::PopulateTFFrameworkToLLVMConversionPatterns(&type_converter,
|
||||
&patterns);
|
||||
patterns.insert<ConvertLaunchFuncOpToTfRuntimeCallPattern>(
|
||||
@ -260,8 +263,9 @@ class TFKernelToLLVMPass : public TFKernelToLLVMPassBase<TFKernelToLLVMPass> {
|
||||
// Set target.
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addIllegalDialect<gpu::GPUDialect, StandardOpsDialect,
|
||||
tf_framework::TFFrameworkDialect>();
|
||||
target
|
||||
.addIllegalDialect<StandardOpsDialect, complex::ComplexDialect,
|
||||
gpu::GPUDialect, tf_framework::TFFrameworkDialect>();
|
||||
target.addIllegalOp<LLVM::DialectCastOp>();
|
||||
// Mark modules as legal.
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp, gpu::GPUModuleOp>();
|
||||
|
Loading…
x
Reference in New Issue
Block a user