diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index ba4749efbed..63fec2498a5 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -902,6 +902,7 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
     ConvolutionAttributes.attributes);

   let results = (outs HLO_Tensor);
+  let hasCustomHLOConverter = 1;
 }

 def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp {
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
index 9b1b1268d33..57fdfb6cfdb 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
@@ -958,6 +958,17 @@ def HLO_PrecisionConfigAttr:
     OptionalAttr<
       TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;

+def BoolElementsAttr :
+    ElementsAttrBase<
+      And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
+           CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
+      "constant boolean vector/tensor attribute"> {
+  let storageType = [{ ::mlir::DenseElementsAttr }];
+  let returnType = [{ ::mlir::DenseElementsAttr }];
+
+  let convertFromStorage = "$_self";
+}
+
 def ConvolutionAttributes {
   dag attributes = (ins
     // Default value: one for each of the spatial dimension.
@@ -968,6 +979,8 @@ def ConvolutionAttributes {
     OptionalAttr<I64ElementsAttr>:$lhs_dilation,
     // Default value: one for each of the spatial dimension.
     OptionalAttr<I64ElementsAttr>:$rhs_dilation,
+    // Default value: false for each of the spatial dimension.
+    OptionalAttr<BoolElementsAttr>:$window_reversal,
     ConvDimensionNumbers:$dimension_numbers,
     I64Attr:$feature_group_count,
     I64Attr:$batch_group_count,
@@ -983,6 +996,14 @@ class BASE_HLO_ConvOp {

     See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
   }];
+
+  code extraClassDeclaration = [{
+    bool hasWindowReversal() {
+      auto reversal = window_reversalAttr();
+      return reversal && llvm::any_of(reversal.getBoolValues(),
+                                      [](bool v) { return v; });
+    }
+  }];
 }

 class BASE_HLO_CopyOp {
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 31b7a60e7a5..9beda1388bd 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -243,7 +243,8 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
     }

     // TODO: LHS dilation for deconvolution not supported yet.
-    if (op.lhs_dilation()) {
+    // TODO(jurahul): Window reversal is not supported yet.
+    if (op.lhs_dilation() || op.hasWindowReversal()) {
       return failure();
     }

diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir
index 35bf59b057e..83327d4d7ef 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo_gpu_ops.mlir
@@ -103,7 +103,8 @@ func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf6
       precision_config = [],
       result_scale = 1.000000e+00 : f64,
       rhs_dilation = dense<1> : tensor<2xi64>,
-      window_strides = dense<1> : tensor<2xi64>}
+      window_strides = dense<1> : tensor<2xi64>,
+      window_reversal = dense<true> : tensor<2xi1>}
     : (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> ()
   return
 }
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 80c0180f604..d6fd6600fa9 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -113,6 +113,7 @@ StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
       ConvertPadding(padding, &builder_),
       GetI64ElementsAttr(lhs_dilation, &builder_),
       GetI64ElementsAttr(rhs_dilation, &builder_),
+      /*window_reversal=*/nullptr,
       ConvertConvDimensionNumbers(dimension_numbers, &builder_),
       builder_.getI64IntegerAttr(feature_group_count),
       builder_.getI64IntegerAttr(batch_group_count), config_attr);
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 2495c341994..34345fe7f3d 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -737,6 +737,26 @@ LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
   return failure();
 }

+LogicalResult ExportXlaOp(mlir::mhlo::ConvOp op, OpLoweringContext ctx) {
+  // XLA client builder API does not support generating convolution instructions
+  // with window reversal.
+  if (op.hasWindowReversal()) return failure();
+  auto& value_map = *ctx.values;
+  xla::XlaOp lhs, rhs;
+  if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
+  if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
+  xla::XlaOp xla_result = xla::ConvGeneralDilated(
+      lhs, rhs, Convert_window_strides(op.window_strides()),
+      Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()),
+      Convert_rhs_dilation(op.rhs_dilation()),
+      Convert_dimension_numbers(op.dimension_numbers()),
+      Convertuint64_t(op.feature_group_count()),
+      Convertuint64_t(op.batch_group_count()),
+      Unwrap(Convert_precision_config(op.precision_config())));
+  value_map[op] = xla_result;
+  return mlir::success();
+}
+
 LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
   auto& value_map = *ctx.values;
   xla::XlaOp operand;
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 0a8f7ddbab6..2d2e7197d0e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -4265,6 +4265,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern<OpTy> {
                                    &rewriter),
         /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
         GetI64ElementsAttr(rhs_dilation, &rewriter),
+        /*window_reversal=*/nullptr,
         ConvDimensionNumbers::get(
             /*input_batch_dimension=*/batch_dim_attr,
             /*input_feature_dimension=*/feature_dim_attr,
@@ -4479,6 +4480,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern<OpTy> {
         GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
                                    &rewriter),
         GetI64ElementsAttr(rhs_dilation, &rewriter),
+        /*window_reversal=*/nullptr,
         ConvDimensionNumbers::get(
             // Swap batch_dim and feature_dim in the activations.
             /*input_batch_dimension=*/feature_dim_attr,