Merge pull request #40064 from polymage-labs:conv

PiperOrigin-RevId: 315495471
Change-Id: I67ce7c82801a303d36bcd5583f009a1f6de948e5
This commit is contained in:
TensorFlower Gardener 2020-06-09 09:15:17 -07:00
commit 56ef3f784b
9 changed files with 203 additions and 6 deletions

View File

@ -28,6 +28,8 @@ package_group(
exports_files(["ir/hlo_ops.td"])
exports_files(["ir/lhlo_ops.td"])
filegroup(
name = "hlo_ops_td_files",
srcs = [
@ -88,6 +90,8 @@ gentbl(
tbl_outs = [
("-gen-op-decls", "ir/lhlo_ops.h.inc"),
("-gen-op-defs", "ir/lhlo_ops.cc.inc"),
("-gen-struct-attr-decls", "ir/lhlo_structs.h.inc"),
("-gen-struct-attr-defs", "ir/lhlo_structs.cc.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ir/lhlo_ops.td",
@ -399,6 +403,7 @@ cc_library(
":map_hlo_to_lhlo_op",
"@com_google_absl//absl/memory",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",

View File

@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h.inc"
namespace mlir {
#include "tensorflow/compiler/mlir/xla/ir/lhlo_structs.cc.inc"
namespace xla_lhlo {
XlaLhloDialect::XlaLhloDialect(MLIRContext *context)

View File

@ -33,6 +33,8 @@ limitations under the License.
namespace mlir {
class OpBuilder;
#include "tensorflow/compiler/mlir/xla/ir/lhlo_structs.h.inc"
namespace xla_lhlo {
class XlaLhloDialect : public Dialect {

View File

@ -407,11 +407,39 @@ def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
);
}
// TODO(bondhugula): Make this struct dialect independent so that it can be
// shared between the HLO and LHLO dialects.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
}

View File

@ -432,7 +432,39 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]],
// CHECK-SAME: %[[RESULT:.*]]: [[TYPE]])
// CHECK: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %{{.*}}) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "xla_hlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
return %dot : tensor<1024x1024xf32>
}
%dot = "xla_hlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
return %dot : tensor<1024x1024xf32>
}
// -----
// CHECK-LABEL: func @conv
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// CHECK: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// CHECK-SAME: padding = dense<[
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
// CHECK-SAME: window_strides = dense<[2, 1]>
%out = "xla_hlo.convolution"(%filter, %input) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>
} : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
return %out : tensor<3x5x5x4xf32>
}

View File

@ -511,7 +511,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32
// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32
// CHECK-NEXT: %[[RESULT:.*]] = subi %[[L0]], %[[OPERAND_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
@ -691,3 +691,27 @@ func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// -----
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
%c0 = constant 0 : index
%0 = alloc() : memref<3x5x5x4xf32>
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 2]
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: strides = [2, 1]}
// With all atributes explicitly specified.
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
// Dilation left unspecified, sets default dilation since linalg expects it.
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 1]
// Padding is not set if it's zero.
// CHECK-NOT: padding
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.terminator"() : () -> ()
}

View File

@ -423,6 +423,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<xla_hlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::ComplexOp>,
HloToLhloOpConverter<xla_hlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConvOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::CopyOp>,
HloToLhloOpConverter<xla_hlo::CosOp>,

View File

@ -45,6 +45,7 @@ MAP_HLO_TO_LHLO(CeilOp);
MAP_HLO_TO_LHLO(ConstOp);
MAP_HLO_TO_LHLO(CompareOp);
MAP_HLO_TO_LHLO(ComplexOp);
MAP_HLO_TO_LHLO(ConvOp);
MAP_HLO_TO_LHLO(ConvertOp);
MAP_HLO_TO_LHLO(CopyOp);
MAP_HLO_TO_LHLO(CosOp);

View File

@ -192,6 +192,108 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
}
};
//===----------------------------------------------------------------------===//
// xla_lhlo.convolution conversion pattern.
//===----------------------------------------------------------------------===//
/// Converts xla_lhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
public:
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's
// (https://github.com/google/iree/) xla_hlo -> linalg conversion.
LogicalResult matchAndRewrite(
xla_lhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information.
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
op.dimension_numbers()) {
const int inputSpatialRank =
llvm::size(dimensionNumbers.input_spatial_dimensions());
// The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count.
if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
dimensionNumbers.input_feature_dimension().getInt() !=
(inputSpatialRank + 1))
return failure();
const int kernelSpatialRank =
llvm::size(dimensionNumbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count.
if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
kernelSpatialRank ||
dimensionNumbers.kernel_output_feature_dimension().getInt() !=
(kernelSpatialRank + 1))
return failure();
const int outputSpatialRank =
llvm::size(dimensionNumbers.output_spatial_dimensions());
// The dimensions for output should follow the order of
// batch_count, spatial_dims.., output_feature_count.
if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
dimensionNumbers.output_feature_dimension().getInt() !=
(outputSpatialRank + 1))
return failure();
if (inputSpatialRank != outputSpatialRank ||
inputSpatialRank != kernelSpatialRank)
return failure();
auto inputSpatialDim =
dimensionNumbers.input_spatial_dimensions().begin();
auto kernelSpatialDim =
dimensionNumbers.kernel_spatial_dimensions().begin();
auto outputSpatialDim =
dimensionNumbers.output_spatial_dimensions().begin();
// Check if spatial dims are ordered correctly.
for (int i = 0; i < inputSpatialRank; ++i) {
const int dim = i + 1;
if ((*inputSpatialDim++).getZExtValue() != dim ||
(*outputSpatialDim++).getZExtValue() != dim ||
(*kernelSpatialDim++).getZExtValue() != i)
return failure();
}
}
// TODO: LHS dilation for deconvolution not supported yet.
if (op.lhs_dilation()) {
return failure();
}
llvm::SmallVector<Attribute, 4> strides;
if (auto windowStrides = op.window_strides()) {
auto range = windowStrides->getAttributeValues();
strides.assign(range.begin(), range.end());
}
auto stridesArg = ArrayAttr::get(strides, op.getContext());
llvm::SmallVector<Attribute, 2> dilation;
if (auto rhsDilation = op.rhs_dilation()) {
auto range = rhsDilation->getAttributeValues();
dilation.assign(range.begin(), range.end());
} else {
// Default dilation of 1.
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
}
auto dilationArg = ArrayAttr::get(dilation, op.getContext());
// Set padding only if it is non-zero.
DenseIntElementsAttr padding = op.paddingAttr();
if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
return !intVal.isNullValue();
})) {
padding = nullptr;
}
// The order of input and filter are switched with linalg.conv.
rewriter.replaceOpWithNewOp<linalg::ConvOp>(
op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
return success();
}
};
/// Base class for lowering xla operations that have one operand and one result,
/// and are semantically equivalent to a copy of the input to the output (like
/// transpose, some reshape, etc.). The derived classes need to provide a method
@ -814,6 +916,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
// clang-format off
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
ConstConverter,
ConvToLinalgConverter,
IotaConverter,
LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,