Add lowering from xla_hlo/lhlo reverse op to Linalg.
This is only supported for static shape. PiperOrigin-RevId: 312173157 Change-Id: Iab149f02153597ef5a967628397fcac9a4db1329
This commit is contained in:
parent
1acf6989bf
commit
ad6e816328
@ -542,3 +542,16 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
|
||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @reverse
|
||||
func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
|
||||
%result = "xla_hlo.reverse"(%input) {
|
||||
dimensions = dense<1> : tensor<1xi64>
|
||||
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
return %result : tensor<2x3xf32>
|
||||
}
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
|
@ -636,3 +636,16 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
|
||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @reverse
|
||||
func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
|
||||
"xla_lhlo.reverse"(%arg0, %arg1) {
|
||||
dimensions = dense<1> : tensor<1xi64>
|
||||
} : (memref<2x3xf32>, memref<2x3xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||
|
@ -573,6 +573,34 @@ class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(b/156787842): Support the lowering for dynamic shapes.
|
||||
template <typename OpTy, bool isLHLO = true>
|
||||
class ReverseConverter
|
||||
: public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
|
||||
isLHLO> {
|
||||
public:
|
||||
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
|
||||
isLHLO>::DataMovementOpConverter;
|
||||
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
|
||||
auto resultType =
|
||||
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
|
||||
auto nloops = resultType.getRank();
|
||||
SmallVector<AffineExpr, 2> inputExprs;
|
||||
inputExprs.reserve(nloops);
|
||||
for (int i = 0; i < nloops; ++i)
|
||||
inputExprs.push_back(b->getAffineDimExpr(i));
|
||||
for (auto dim : op.dimensions()) {
|
||||
int i = dim.getZExtValue();
|
||||
if (resultType.isDynamicDim(i)) return {};
|
||||
int n = resultType.getShape()[i];
|
||||
inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
|
||||
}
|
||||
return b->getAffineMapArrayAttr(
|
||||
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
|
||||
b->getMultiDimIdentityMap(nloops)});
|
||||
}
|
||||
};
|
||||
|
||||
class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern;
|
||||
@ -642,6 +670,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
|
||||
ReshapeAddRemoveDimConverter<xla_lhlo::ReshapeOp>,
|
||||
ReverseConverter<xla_lhlo::ReverseOp>,
|
||||
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
|
||||
SliceConverter
|
||||
>(context);
|
||||
@ -742,6 +771,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
|
||||
ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
|
||||
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
|
||||
ReverseConverter<xla_hlo::ReverseOp, false>,
|
||||
TransposeConverter<xla_hlo::TransposeOp, false>>(context);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user