diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 736651b5022..972d5ea2c51 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -28,6 +28,8 @@ package_group(
 
 exports_files(["ir/hlo_ops.td"])
 
+exports_files(["ir/lhlo_ops.td"])
+
 filegroup(
     name = "hlo_ops_td_files",
     srcs = [
@@ -87,6 +89,8 @@ gentbl(
     tbl_outs = [
        ("-gen-op-decls", "ir/lhlo_ops.h.inc"),
        ("-gen-op-defs", "ir/lhlo_ops.cc.inc"),
+        ("-gen-struct-attr-decls", "ir/lhlo_structs.h.inc"),
+        ("-gen-struct-attr-defs", "ir/lhlo_structs.cc.inc"),
     ],
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "ir/lhlo_ops.td",
@@ -362,6 +366,7 @@ cc_library(
        ":map_hlo_to_lhlo_op",
        "@com_google_absl//absl/memory",
        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgOps",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Transforms",
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc
index 6f9b39377af..24cffa756ec 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc
@@ -45,6 +45,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h.inc"
 
 namespace mlir {
+#include "tensorflow/compiler/mlir/xla/ir/lhlo_structs.cc.inc"
 namespace xla_lhlo {
 
 XlaLhloDialect::XlaLhloDialect(MLIRContext *context)
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h
index 3827e8a7a4e..6ea5e2522c2 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h
@@ -33,6 +33,8 @@ limitations under the License.
 namespace mlir {
 class OpBuilder;
 
+#include "tensorflow/compiler/mlir/xla/ir/lhlo_structs.h.inc"
+
 namespace xla_lhlo {
 
 class XlaLhloDialect : public Dialect {
diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
index d9f3648bb09..6ba9935d85e 100644
--- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
@@ -407,11 +407,39 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
   );
 }
 
+// TODO(bondhugula): Make this struct dialect independent so that it can be
+// shared between the HLO and LHLO dialects.
+def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
+    StructFieldAttr<"input_batch_dimension", I64Attr>,
+    StructFieldAttr<"input_feature_dimension", I64Attr>,
+    StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
+    StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
+    StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
+    StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
+    StructFieldAttr<"output_batch_dimension", I64Attr>,
+    StructFieldAttr<"output_feature_dimension", I64Attr>,
+    StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>]> {
+  let description = "Structure of dimension information for the conv op";
+}
+
 def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
   let arguments = (ins
     Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
     Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
-    Arg<LHLO_Buffer, "", [MemWrite]>:$output
+    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
+    // Default value: one for each of the spatial dimensions.
+    OptionalAttr<I64ElementsAttr>:$window_strides,
+    // Default value: zero for each of the spatial dimensions.
+    OptionalAttr<I64ElementsAttr>:$padding,
+    // Default value: one for each of the spatial dimensions.
+    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
+    // Default value: one for each of the spatial dimensions.
+    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
+    ConvDimensionNumbers:$dimension_numbers,
+    I64Attr:$feature_group_count,
+    I64Attr:$batch_group_count,
+    HLO_PrecisionConfigAttr:$precision_config
   );
 }
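For context on what the two new `gentbl` outputs contain: `-gen-struct-attr-decls` turns the `ConvDimensionNumbers` `StructAttr` above into a C++ wrapper over a `DictionaryAttr`, with one typed getter per `StructFieldAttr`. A rough sketch of the generated declaration (illustrative only; the exact names and signatures come from mlir-tblgen):

    // Approximate shape of ir/lhlo_structs.h.inc -- not the literal output.
    class ConvDimensionNumbers : public mlir::DictionaryAttr {
     public:
      using DictionaryAttr::DictionaryAttr;
      static bool classof(mlir::Attribute attr);

      // I64Attr fields surface as IntegerAttr getters,
      // I64ElementsAttr fields as ElementsAttr getters.
      mlir::IntegerAttr input_batch_dimension() const;
      mlir::IntegerAttr input_feature_dimension() const;
      mlir::ElementsAttr input_spatial_dimensions() const;
      mlir::IntegerAttr kernel_input_feature_dimension() const;
      mlir::IntegerAttr kernel_output_feature_dimension() const;
      mlir::ElementsAttr kernel_spatial_dimensions() const;
      mlir::IntegerAttr output_batch_dimension() const;
      mlir::IntegerAttr output_feature_dimension() const;
      mlir::ElementsAttr output_spatial_dimensions() const;
      // A static get(...) factory taking one attribute per field is also
      // generated (elided here).
    };

These getters are what the conversion pattern further down relies on, e.g. `dimensionNumbers.input_batch_dimension().getInt()`.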
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
index 38ea818aea8..aca4bf5865d 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-lhlo.mlir
@@ -432,7 +432,39 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 // CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]],
 // CHECK-SAME: %[[RESULT:.*]]: [[TYPE]])
 // CHECK: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %{{.*}}) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
-  %dot = "xla_hlo.dot"(%arg0, %arg0)
-          : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
-  return %dot : tensor<1024x1024xf32>
-  }
+  %dot = "xla_hlo.dot"(%arg0, %arg0)
+      : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+  return %dot : tensor<1024x1024xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @conv
+func @conv(%input: tensor<3x5x5x3xf32>, %filter: tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
+  %c0 = constant 0 : index
+  // CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
+  // CHECK: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
+  // CHECK-SAME: padding = dense<[
+  // CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
+  // CHECK-SAME: rhs_dilation = dense<[1, 2]>
+  // CHECK-SAME: window_strides = dense<[2, 1]>
+  %out = "xla_hlo.convolution"(%filter, %input) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 2 : i64,
+      kernel_output_feature_dimension = 3 : i64,
+      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
+      output_batch_dimension = 0 : i64,
+      output_feature_dimension = 3 : i64,
+      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
+    },
+    feature_group_count = 1 : i64,
+    padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
+    rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
+    window_strides = dense<[2, 1]> : tensor<2xi64>
+  } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
+  return %out : tensor<3x5x5x4xf32>
+}
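The `alloc()` that the CHECK lines above expect is not conv-specific: it comes from the generic `HloToLhloOpConverter` that this change registers for `ConvOp` (see the `hlo_legalize_to_lhlo.cc` hunk below). Schematically, that pattern does the following -- a simplified sketch of the existing converter, where `InsertAllocAndDealloc` stands in for the pass's buffer-allocation helper and exact signatures may differ:

    // Simplified sketch of the generic HLO->LHLO conversion pattern.
    template <typename HloOpTy>
    struct HloToLhloOpConverter : public OpConversionPattern<HloOpTy> {
      using OpConversionPattern<HloOpTy>::OpConversionPattern;

      LogicalResult matchAndRewrite(
          HloOpTy hloOp, ArrayRef<Value> operands,
          ConversionPatternRewriter& rewriter) const final {
        Operation* op = hloOp.getOperation();
        // Tensor operands have already been rewritten to memrefs; also
        // allocate one buffer per tensor result.
        SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
        for (Value result : op->getResults()) {
          buffer_args.push_back(
              InsertAllocAndDealloc(op->getLoc(), result, &rewriter));
        }
        // Create the LHLO twin of the HLO op. It returns no values, writes
        // into the trailing output buffers, and carries over all attributes
        // (dimension_numbers, padding, ...) unchanged.
        rewriter.create<HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
                                              buffer_args, op->getAttrs());
        rewriter.replaceOp(
            op, ArrayRef<Value>(buffer_args).slice(operands.size()));
        return success();
      }
    };

Because attributes are forwarded verbatim, supporting convolution on this path reduces to declaring the matching attributes on the LHLO op (the `lhlo_ops.td` hunk above) and registering the op mapping (the two hunks below).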
diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
index 626e905695c..ce5d0d28076 100644
--- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir
@@ -469,7 +469,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
 }
 // CHECK: linalg.generic
 // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]):
-// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32 
+// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32
 // CHECK-NEXT: %[[RESULT:.*]] = subi %[[L0]], %[[OPERAND_IN]] : i32
 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32
@@ -649,3 +649,27 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
   return
 }
 // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
+
+
+// -----
+
+func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
+  %c0 = constant 0 : index
+  %0 = alloc() : memref<3x5x5x4xf32>
+  // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
+  // CHECK-SAME: dilations = [1, 2]
+  // CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
+  // CHECK-SAME: strides = [2, 1]}
+  // With all attributes explicitly specified.
+  "xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
+
+  // Dilation left unspecified; the lowering sets the default dilation since
+  // linalg expects it.
+  // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
+  // CHECK-SAME: dilations = [1, 1]
+  // The padding attribute is not emitted if the padding is zero.
+  // CHECK-NOT: padding
+  "xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
+
+  "xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
+  "xla_lhlo.terminator"() : () -> ()
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
index 6f5bafef4c0..45aed7e10ff 100644
--- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc
@@ -424,6 +424,7 @@ void populateHLOToLHLOConversionPattern(
       HloToLhloOpConverter<xla_hlo::ConstOp>,
       HloToLhloOpConverter<xla_hlo::CompareOp>,
       HloToLhloOpConverter<xla_hlo::ComplexOp>,
+      HloToLhloOpConverter<xla_hlo::ConvOp>,
       HloToLhloOpConverter<xla_hlo::ConvertOp>,
       HloToLhloOpConverter<xla_hlo::CopyOp>,
       HloToLhloOpConverter<xla_hlo::CosOp>,
diff --git a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h
index 21b954a3eb4..4b9397795a1 100644
--- a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h
+++ b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h
@@ -45,6 +45,7 @@ MAP_HLO_TO_LHLO(CeilOp);
 MAP_HLO_TO_LHLO(ConstOp);
 MAP_HLO_TO_LHLO(CompareOp);
 MAP_HLO_TO_LHLO(ComplexOp);
+MAP_HLO_TO_LHLO(ConvOp);
 MAP_HLO_TO_LHLO(ConvertOp);
 MAP_HLO_TO_LHLO(CopyOp);
 MAP_HLO_TO_LHLO(CosOp);
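For context, `MAP_HLO_TO_LHLO` is the macro that makes `HloToLhloOp<xla_hlo::ConvOp>` resolve to `xla_lhlo::ConvOp`. Its definition in `map_hlo_to_lhlo_op.h` is essentially the following specialization machinery (reconstructed from its usage; exact formatting may differ):

    // Unmapped ops resolve to std::false_type.
    template <typename HloOpTy>
    struct HloToLhloOpImpl {
      using Type = std::false_type;
    };
    template <typename HloOpTy>
    using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;

    // Each MAP_HLO_TO_LHLO(OpName) adds one specialization.
    #define MAP_HLO_TO_LHLO(OpName)             \
      template <>                               \
      struct HloToLhloOpImpl<xla_hlo::OpName> { \
        using Type = xla_lhlo::OpName;          \
      }

So `MAP_HLO_TO_LHLO(ConvOp);` is the one-line registration that lets the generic converter sketched earlier instantiate `xla_lhlo::ConvOp` when rewriting `xla_hlo::ConvOp`.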
diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc
index 2b496677d62..fd0c9541e7c 100644
--- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc
@@ -192,6 +192,108 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// xla_lhlo.convolution conversion pattern.
+//===----------------------------------------------------------------------===//
+
+/// Converts an xla_lhlo.convolution operation to a linalg.conv op.
+struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
+ public:
+  using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
+
+  // This code has been adapted from IREE's
+  // (https://github.com/google/iree/) xla_hlo -> linalg conversion.
+  LogicalResult matchAndRewrite(
+      xla_lhlo::ConvOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    // Check validity of dimension information.
+    if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
+            op.dimension_numbers()) {
+      const int inputSpatialRank =
+          llvm::size(dimensionNumbers.input_spatial_dimensions());
+      // The input dimensions should follow the order
+      // batch_count, spatial_dims..., input_feature_count.
+      if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
+          dimensionNumbers.input_feature_dimension().getInt() !=
+              (inputSpatialRank + 1))
+        return failure();
+
+      const int kernelSpatialRank =
+          llvm::size(dimensionNumbers.kernel_spatial_dimensions());
+      // The kernel dimensions should follow the order
+      // spatial_dims..., input_feature_count, output_feature_count.
+      if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
+              kernelSpatialRank ||
+          dimensionNumbers.kernel_output_feature_dimension().getInt() !=
+              (kernelSpatialRank + 1))
+        return failure();
+
+      const int outputSpatialRank =
+          llvm::size(dimensionNumbers.output_spatial_dimensions());
+      // The output dimensions should follow the order
+      // batch_count, spatial_dims..., output_feature_count.
+      if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
+          dimensionNumbers.output_feature_dimension().getInt() !=
+              (outputSpatialRank + 1))
+        return failure();
+
+      if (inputSpatialRank != outputSpatialRank ||
+          inputSpatialRank != kernelSpatialRank)
+        return failure();
+
+      auto inputSpatialDim =
+          dimensionNumbers.input_spatial_dimensions().begin();
+      auto kernelSpatialDim =
+          dimensionNumbers.kernel_spatial_dimensions().begin();
+      auto outputSpatialDim =
+          dimensionNumbers.output_spatial_dimensions().begin();
+      // Check that the spatial dimensions are ordered correctly.
+      for (int i = 0; i < inputSpatialRank; ++i) {
+        const int dim = i + 1;
+        if ((*inputSpatialDim++).getZExtValue() != dim ||
+            (*outputSpatialDim++).getZExtValue() != dim ||
+            (*kernelSpatialDim++).getZExtValue() != i)
+          return failure();
+      }
+    }
+
+    // TODO: LHS dilation (used for deconvolution) is not supported yet.
+    if (op.lhs_dilation()) {
+      return failure();
+    }
+
+    llvm::SmallVector<Attribute, 2> strides;
+    if (auto windowStrides = op.window_strides()) {
+      auto range = windowStrides->getAttributeValues();
+      strides.assign(range.begin(), range.end());
+    }
+    auto stridesArg = ArrayAttr::get(strides, op.getContext());
+
+    llvm::SmallVector<Attribute, 2> dilation;
+    if (auto rhsDilation = op.rhs_dilation()) {
+      auto range = rhsDilation->getAttributeValues();
+      dilation.assign(range.begin(), range.end());
+    } else {
+      // Default dilation of 1, which linalg.conv expects to be present.
+      dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
+    }
+    auto dilationArg = ArrayAttr::get(dilation, op.getContext());
+
+    // Set padding only if it is non-zero.
+    DenseIntElementsAttr padding = op.paddingAttr();
+    if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
+          return !intVal.isNullValue();
+        })) {
+      padding = nullptr;
+    }
+
+    // The order of input and filter is switched with linalg.conv.
+    rewriter.replaceOpWithNewOp<linalg::ConvOp>(
+        op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
+    return success();
+  }
+};
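Concretely, the dimension checks above only admit the canonical NHWC input / HWIO filter / NHWC output layout, which is exactly what the tests in this change use. A worked instance for two spatial dimensions:

    // Accepted dimension_numbers for 2 spatial dims (as in the tests above):
    //   input:  batch = 0, spatial = [1, 2], feature = 3          // NHWC
    //   kernel: spatial = [0, 1], input_feature = 2,
    //           output_feature = 3                                // HWIO
    //   output: batch = 0, spatial = [1, 2], feature = 3          // NHWC
    // Any other layout makes matchAndRewrite return failure(), leaving the
    // op untouched for other patterns or a later legalization diagnostic.

The operand swap in the final `replaceOpWithNewOp<linalg::ConvOp>` call (`args[1], args[0], args[2]`) reflects that `linalg.conv` takes the filter as its first operand, whereas the LHLO op lists it second.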
+
 /// Base class for lowering xla operations that have one operand and one result,
 /// and are semantically equivalent to a copy of the input to the output (like
 /// transpose, some reshape, etc.). The derived classes need to provide a method
@@ -641,6 +743,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
   patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
                    BroadcastInDimConverter,
                    ConstConverter,
+                   ConvToLinalgConverter,
                    IotaConverter,
                    PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
                    PointwiseToLinalgConverter<xla_lhlo::AddOp>,
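For completeness, these patterns are driven by the usual dialect-conversion boilerplate. Below is a minimal sketch of a pass applying them, written against the MLIR API of this vintage (`FunctionPass`, `OwningRewritePatternList`); the actual pass already exists in `xla_legalize_to_linalg.cc`, and names here are illustrative:

    // Sketch only: mirrors the shape of the existing lhlo-to-linalg pass.
    struct LhloLegalizeToLinalgSketch
        : public FunctionPass<LhloLegalizeToLinalgSketch> {
      void runOnFunction() override {
        OwningRewritePatternList patterns;
        ConversionTarget target(getContext());
        // Ops from these dialects are legal after conversion.
        target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();

        auto func = getFunction();
        populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
        if (failed(applyPartialConversion(func, target, patterns))) {
          signalPassFailure();
        }
      }
    };

With `ConvToLinalgConverter` now in the pattern list, running such a pass over the `lhlo-legalize-to-linalg.mlir` test above rewrites `xla_lhlo.convolution` into `linalg.conv` with the translated `strides`, `dilations`, and `padding` attributes.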