Merge pull request #40064 from polymage-labs:conv
PiperOrigin-RevId: 315495471 Change-Id: I67ce7c82801a303d36bcd5583f009a1f6de948e5
This commit is contained in:
commit
56ef3f784b
|
@ -28,6 +28,8 @@ package_group(
|
|||
|
||||
exports_files(["ir/hlo_ops.td"])
|
||||
|
||||
exports_files(["ir/lhlo_ops.td"])
|
||||
|
||||
filegroup(
|
||||
name = "hlo_ops_td_files",
|
||||
srcs = [
|
||||
|
@ -88,6 +90,8 @@ gentbl(
|
|||
tbl_outs = [
|
||||
("-gen-op-decls", "ir/lhlo_ops.h.inc"),
|
||||
("-gen-op-defs", "ir/lhlo_ops.cc.inc"),
|
||||
("-gen-struct-attr-decls", "ir/lhlo_structs.h.inc"),
|
||||
("-gen-struct-attr-defs", "ir/lhlo_structs.cc.inc"),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
td_file = "ir/lhlo_ops.td",
|
||||
|
@ -399,6 +403,7 @@ cc_library(
|
|||
":map_hlo_to_lhlo_op",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LinalgOps",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
|
|
|
@ -45,6 +45,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_structs.cc.inc"
|
||||
namespace xla_lhlo {
|
||||
|
||||
XlaLhloDialect::XlaLhloDialect(MLIRContext *context)
|
||||
|
|
|
@ -33,6 +33,8 @@ limitations under the License.
|
|||
namespace mlir {
|
||||
class OpBuilder;
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/ir/lhlo_structs.h.inc"
|
||||
|
||||
namespace xla_lhlo {
|
||||
|
||||
class XlaLhloDialect : public Dialect {
|
||||
|
|
|
@ -407,11 +407,39 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
|
|||
);
|
||||
}
|
||||
|
||||
// TODO(bondhugula): Make this struct dialect independent so that it can be
|
||||
// shared between the HLO and LHLO dialects.
|
||||
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
|
||||
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||
|
||||
let description = "Structure of dimension information for conv op";
|
||||
}
|
||||
|
||||
// Buffer-based (LHLO) convolution op. Reads the `lhs` and `rhs` buffers and
// writes the result into the `output` buffer. The window attributes are
// optional; `dimension_numbers`, `feature_group_count`, `batch_group_count`
// and `precision_config` are declared as non-optional attributes.
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    // Default value: zero for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$padding,
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
    ConvDimensionNumbers:$dimension_numbers,
    I64Attr:$feature_group_count,
    I64Attr:$batch_group_count,
    HLO_PrecisionConfigAttr:$precision_config
  );
}
|
||||
|
||||
|
|
|
@ -435,4 +435,36 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
|||
%dot = "xla_hlo.dot"(%arg0, %arg0)
|
||||
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
||||
return %dot : tensor<1024x1024xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @conv
|
||||
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
|
||||
// CHECK: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||
// CHECK-SAME: padding = dense<[
|
||||
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
|
||||
// CHECK-SAME: window_strides = dense<[2, 1]>
|
||||
%out = "xla_hlo.convolution"(%filter, %input) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
input_batch_dimension = 0 : i64,
|
||||
input_feature_dimension = 3 : i64,
|
||||
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
|
||||
kernel_input_feature_dimension = 2 : i64,
|
||||
kernel_output_feature_dimension = 3 : i64,
|
||||
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
|
||||
output_batch_dimension = 0 : i64,
|
||||
output_feature_dimension = 3 : i64,
|
||||
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
},
|
||||
feature_group_count = 1 : i64,
|
||||
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
|
||||
window_strides = dense<[2, 1]> : tensor<2xi64>
|
||||
} : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
|
||||
return %out : tensor<3x5x5x4xf32>
|
||||
}
|
||||
|
|
|
@ -691,3 +691,27 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
|
|||
return
|
||||
}
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%0 = alloc() : memref<3x5x5x4xf32>
|
||||
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
|
||||
// CHECK-SAME: dilations = [1, 2]
|
||||
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// CHECK-SAME: strides = [2, 1]}
|
||||
// With all attributes explicitly specified.
|
||||
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
|
||||
// Dilation left unspecified, sets default dilation since linalg expects it.
|
||||
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
|
||||
// CHECK-SAME: dilations = [1, 1]
|
||||
// Padding is not set if it's zero.
|
||||
// CHECK-NOT: padding
|
||||
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
|
||||
"xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
|
|
@ -423,6 +423,7 @@ void populateHLOToLHLOConversionPattern(
|
|||
HloToLhloOpConverter<xla_hlo::CompareOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ComplexOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConstOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConvOp>,
|
||||
HloToLhloOpConverter<xla_hlo::ConvertOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CopyOp>,
|
||||
HloToLhloOpConverter<xla_hlo::CosOp>,
|
||||
|
|
|
@ -45,6 +45,7 @@ MAP_HLO_TO_LHLO(CeilOp);
|
|||
MAP_HLO_TO_LHLO(ConstOp);
|
||||
MAP_HLO_TO_LHLO(CompareOp);
|
||||
MAP_HLO_TO_LHLO(ComplexOp);
|
||||
MAP_HLO_TO_LHLO(ConvOp);
|
||||
MAP_HLO_TO_LHLO(ConvertOp);
|
||||
MAP_HLO_TO_LHLO(CopyOp);
|
||||
MAP_HLO_TO_LHLO(CosOp);
|
||||
|
|
|
@ -192,6 +192,108 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
|||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// xla_lhlo.convolution conversion pattern.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Converts xla_lhlo.convolution operation to a linalg.conv op.
|
||||
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
|
||||
|
||||
// This code has been adapted from IREE's
|
||||
// (https://github.com/google/iree/) xla_hlo -> linalg conversion.
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ConvOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
// Check validity of dimension information.
|
||||
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
|
||||
op.dimension_numbers()) {
|
||||
const int inputSpatialRank =
|
||||
llvm::size(dimensionNumbers.input_spatial_dimensions());
|
||||
// The dimensions for input should follow the order of
|
||||
// batch_count, spatial_dims..., input_feature_count.
|
||||
if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
|
||||
dimensionNumbers.input_feature_dimension().getInt() !=
|
||||
(inputSpatialRank + 1))
|
||||
return failure();
|
||||
|
||||
const int kernelSpatialRank =
|
||||
llvm::size(dimensionNumbers.kernel_spatial_dimensions());
|
||||
// The dimensions for filter should follow the order of
|
||||
// spatial_dims..., input_feature_count, num_output_feature_count.
|
||||
if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
|
||||
kernelSpatialRank ||
|
||||
dimensionNumbers.kernel_output_feature_dimension().getInt() !=
|
||||
(kernelSpatialRank + 1))
|
||||
return failure();
|
||||
|
||||
const int outputSpatialRank =
|
||||
llvm::size(dimensionNumbers.output_spatial_dimensions());
|
||||
// The dimensions for output should follow the order of
|
||||
// batch_count, spatial_dims.., output_feature_count.
|
||||
if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
|
||||
dimensionNumbers.output_feature_dimension().getInt() !=
|
||||
(outputSpatialRank + 1))
|
||||
return failure();
|
||||
|
||||
if (inputSpatialRank != outputSpatialRank ||
|
||||
inputSpatialRank != kernelSpatialRank)
|
||||
return failure();
|
||||
|
||||
auto inputSpatialDim =
|
||||
dimensionNumbers.input_spatial_dimensions().begin();
|
||||
auto kernelSpatialDim =
|
||||
dimensionNumbers.kernel_spatial_dimensions().begin();
|
||||
auto outputSpatialDim =
|
||||
dimensionNumbers.output_spatial_dimensions().begin();
|
||||
// Check if spatial dims are ordered correctly.
|
||||
for (int i = 0; i < inputSpatialRank; ++i) {
|
||||
const int dim = i + 1;
|
||||
if ((*inputSpatialDim++).getZExtValue() != dim ||
|
||||
(*outputSpatialDim++).getZExtValue() != dim ||
|
||||
(*kernelSpatialDim++).getZExtValue() != i)
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: LHS dilation for deconvolution not supported yet.
|
||||
if (op.lhs_dilation()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
llvm::SmallVector<Attribute, 4> strides;
|
||||
if (auto windowStrides = op.window_strides()) {
|
||||
auto range = windowStrides->getAttributeValues();
|
||||
strides.assign(range.begin(), range.end());
|
||||
}
|
||||
auto stridesArg = ArrayAttr::get(strides, op.getContext());
|
||||
|
||||
llvm::SmallVector<Attribute, 2> dilation;
|
||||
if (auto rhsDilation = op.rhs_dilation()) {
|
||||
auto range = rhsDilation->getAttributeValues();
|
||||
dilation.assign(range.begin(), range.end());
|
||||
} else {
|
||||
// Default dilation of 1.
|
||||
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
|
||||
}
|
||||
auto dilationArg = ArrayAttr::get(dilation, op.getContext());
|
||||
|
||||
// Set padding only if it is non-zero.
|
||||
DenseIntElementsAttr padding = op.paddingAttr();
|
||||
if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
|
||||
return !intVal.isNullValue();
|
||||
})) {
|
||||
padding = nullptr;
|
||||
}
|
||||
|
||||
// The order of input and filter are switched with linalg.conv.
|
||||
rewriter.replaceOpWithNewOp<linalg::ConvOp>(
|
||||
op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Base class for lowering xla operations that have one operand and one result,
|
||||
/// and are semantically equivalent to a copy of the input to the output (like
|
||||
/// transpose, some reshape, etc.). The derived classes need to provide a method
|
||||
|
@ -814,6 +916,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
// clang-format off
|
||||
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
|
||||
ConstConverter,
|
||||
ConvToLinalgConverter,
|
||||
IotaConverter,
|
||||
LhloBroadcastInDimConverter,
|
||||
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
|
||||
|
|
Loading…
Reference in New Issue