Add lowering from xla_hlo/lhlo reverse op to Linalg.

This is only supported for static shape.

PiperOrigin-RevId: 312173157
Change-Id: Iab149f02153597ef5a967628397fcac9a4db1329
This commit is contained in:
Hanhan Wang 2020-05-18 16:06:46 -07:00 committed by TensorFlower Gardener
parent 1acf6989bf
commit ad6e816328
3 changed files with 56 additions and 0 deletions

View File

@ -542,3 +542,16 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
%result = "xla_hlo.reverse"(%input) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %result : tensor<2x3xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]

View File

@ -636,3 +636,16 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
"xla_lhlo.reverse"(%arg0, %arg1) {
dimensions = dense<1> : tensor<1xi64>
} : (memref<2x3xf32>, memref<2x3xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]

View File

@ -573,6 +573,34 @@ class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> {
}
};
// TODO(b/156787842): Support the lowering for dynamic shapes.
template <typename OpTy, bool isLHLO = true>
class ReverseConverter
: public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
isLHLO> {
public:
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter;
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.reserve(nloops);
for (int i = 0; i < nloops; ++i)
inputExprs.push_back(b->getAffineDimExpr(i));
for (auto dim : op.dimensions()) {
int i = dim.getZExtValue();
if (resultType.isDynamicDim(i)) return {};
int n = resultType.getShape()[i];
inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
}
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
}
};
class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
public:
using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern;
@ -642,6 +670,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
ReshapeAddRemoveDimConverter<xla_lhlo::ReshapeOp>,
ReverseConverter<xla_lhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
SliceConverter
>(context);
@ -742,6 +771,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
ReverseConverter<xla_hlo::ReverseOp, false>,
TransposeConverter<xla_hlo::TransposeOp, false>>(context);
}