[MLIR:HLO] Add window_reversal attribute to convolution attributes.
- Add this attribute to match the corresponding XLA HLO attribute on convolution operations. - A true value indicates a reversal of the corresponding kernel spatial dimension. - Since XLA builder does not support this attribute, use a custom HLO converted to map from mlir::mhlo::ConvOp to XLA. PiperOrigin-RevId: 346891737 Change-Id: I5c3aa4f6229d7f17970ae36b88bfbfc1bd137b08
This commit is contained in:
parent
ece423eb03
commit
2948461bab
@ -902,6 +902,7 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
|
|||||||
ConvolutionAttributes.attributes);
|
ConvolutionAttributes.attributes);
|
||||||
|
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp {
|
def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp {
|
||||||
|
@ -958,6 +958,17 @@ def HLO_PrecisionConfigAttr:
|
|||||||
OptionalAttr<
|
OptionalAttr<
|
||||||
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
|
TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
|
||||||
|
|
||||||
|
def BoolElementsAttr :
|
||||||
|
ElementsAttrBase<
|
||||||
|
And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">,
|
||||||
|
CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>,
|
||||||
|
"constant boolean vector/tensor attribute"> {
|
||||||
|
let storageType = [{ ::mlir::DenseElementsAttr }];
|
||||||
|
let returnType = [{ ::mlir::DenseElementsAttr }];
|
||||||
|
|
||||||
|
let convertFromStorage = "$_self";
|
||||||
|
}
|
||||||
|
|
||||||
def ConvolutionAttributes {
|
def ConvolutionAttributes {
|
||||||
dag attributes = (ins
|
dag attributes = (ins
|
||||||
// Default value: one for each of the spatial dimension.
|
// Default value: one for each of the spatial dimension.
|
||||||
@ -968,6 +979,8 @@ def ConvolutionAttributes {
|
|||||||
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
||||||
// Default value: one for each of the spatial dimension.
|
// Default value: one for each of the spatial dimension.
|
||||||
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<BoolElementsAttr>:$window_reversal,
|
||||||
ConvDimensionNumbers:$dimension_numbers,
|
ConvDimensionNumbers:$dimension_numbers,
|
||||||
I64Attr:$feature_group_count,
|
I64Attr:$feature_group_count,
|
||||||
I64Attr:$batch_group_count,
|
I64Attr:$batch_group_count,
|
||||||
@ -983,6 +996,14 @@ class BASE_HLO_ConvOp {
|
|||||||
|
|
||||||
See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
|
See https://www.tensorflow.org/xla/operation_semantics#conv_convolution.
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
code extraClassDeclaration = [{
|
||||||
|
bool hasWindowReversal() {
|
||||||
|
auto reversal = window_reversalAttr();
|
||||||
|
return reversal && llvm::any_of(reversal.getBoolValues(),
|
||||||
|
[](bool v) { return v; });
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
class BASE_HLO_CopyOp {
|
class BASE_HLO_CopyOp {
|
||||||
|
@ -243,7 +243,8 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: LHS dilation for deconvolution not supported yet.
|
// TODO: LHS dilation for deconvolution not supported yet.
|
||||||
if (op.lhs_dilation()) {
|
// TODO(jurahul): Window reversal is not supported yet.
|
||||||
|
if (op.lhs_dilation() || op.hasWindowReversal()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,7 +103,8 @@ func @conv_backinput(%input : memref<4x5x16x16xf64>, %filter : memref<5x3x7x7xf6
|
|||||||
precision_config = [],
|
precision_config = [],
|
||||||
result_scale = 1.000000e+00 : f64,
|
result_scale = 1.000000e+00 : f64,
|
||||||
rhs_dilation = dense<1> : tensor<2xi64>,
|
rhs_dilation = dense<1> : tensor<2xi64>,
|
||||||
window_strides = dense<1> : tensor<2xi64>}
|
window_strides = dense<1> : tensor<2xi64>,
|
||||||
|
window_reversal = dense<true>: tensor<2xi1>}
|
||||||
: (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> ()
|
: (memref<4x5x16x16xf64>, memref<5x3x7x7xf64>, memref<4x3x16x16xf64>, memref<32xui8>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -113,6 +113,7 @@ StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
|
|||||||
ConvertPadding(padding, &builder_),
|
ConvertPadding(padding, &builder_),
|
||||||
GetI64ElementsAttr(lhs_dilation, &builder_),
|
GetI64ElementsAttr(lhs_dilation, &builder_),
|
||||||
GetI64ElementsAttr(rhs_dilation, &builder_),
|
GetI64ElementsAttr(rhs_dilation, &builder_),
|
||||||
|
/*window_reversal=*/nullptr,
|
||||||
ConvertConvDimensionNumbers(dimension_numbers, &builder_),
|
ConvertConvDimensionNumbers(dimension_numbers, &builder_),
|
||||||
builder_.getI64IntegerAttr(feature_group_count),
|
builder_.getI64IntegerAttr(feature_group_count),
|
||||||
builder_.getI64IntegerAttr(batch_group_count), config_attr);
|
builder_.getI64IntegerAttr(batch_group_count), config_attr);
|
||||||
|
@ -737,6 +737,26 @@ LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult ExportXlaOp(mlir::mhlo::ConvOp op, OpLoweringContext ctx) {
|
||||||
|
// XLA client builder API does not support generating convolution instructions
|
||||||
|
// with window reversal.
|
||||||
|
if (op.hasWindowReversal()) return failure();
|
||||||
|
auto& value_map = *ctx.values;
|
||||||
|
xla::XlaOp lhs, rhs;
|
||||||
|
if (failed(GetXlaOp(op.lhs(), value_map, &lhs, op))) return mlir::failure();
|
||||||
|
if (failed(GetXlaOp(op.rhs(), value_map, &rhs, op))) return mlir::failure();
|
||||||
|
xla::XlaOp xla_result = xla::ConvGeneralDilated(
|
||||||
|
lhs, rhs, Convert_window_strides(op.window_strides()),
|
||||||
|
Convert_padding(op.padding()), Convert_lhs_dilation(op.lhs_dilation()),
|
||||||
|
Convert_rhs_dilation(op.rhs_dilation()),
|
||||||
|
Convert_dimension_numbers(op.dimension_numbers()),
|
||||||
|
Convertuint64_t(op.feature_group_count()),
|
||||||
|
Convertuint64_t(op.batch_group_count()),
|
||||||
|
Unwrap(Convert_precision_config(op.precision_config())));
|
||||||
|
value_map[op] = xla_result;
|
||||||
|
return mlir::success();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
|
LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) {
|
||||||
auto& value_map = *ctx.values;
|
auto& value_map = *ctx.values;
|
||||||
xla::XlaOp operand;
|
xla::XlaOp operand;
|
||||||
|
@ -4265,6 +4265,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern<OpTy> {
|
|||||||
&rewriter),
|
&rewriter),
|
||||||
/*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
|
/*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter),
|
||||||
GetI64ElementsAttr(rhs_dilation, &rewriter),
|
GetI64ElementsAttr(rhs_dilation, &rewriter),
|
||||||
|
/*window_reversal=*/nullptr,
|
||||||
ConvDimensionNumbers::get(
|
ConvDimensionNumbers::get(
|
||||||
/*input_batch_dimension=*/batch_dim_attr,
|
/*input_batch_dimension=*/batch_dim_attr,
|
||||||
/*input_feature_dimension=*/feature_dim_attr,
|
/*input_feature_dimension=*/feature_dim_attr,
|
||||||
@ -4479,6 +4480,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern<OpTy> {
|
|||||||
GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
|
GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1,
|
||||||
&rewriter),
|
&rewriter),
|
||||||
GetI64ElementsAttr(rhs_dilation, &rewriter),
|
GetI64ElementsAttr(rhs_dilation, &rewriter),
|
||||||
|
/*window_reversal=*/nullptr,
|
||||||
ConvDimensionNumbers::get(
|
ConvDimensionNumbers::get(
|
||||||
// Swap batch_dim and feature_dim in the activations.
|
// Swap batch_dim and feature_dim in the activations.
|
||||||
/*input_batch_dimension=*/feature_dim_attr,
|
/*input_batch_dimension=*/feature_dim_attr,
|
||||||
|
Loading…
Reference in New Issue
Block a user