[XLA:GPU][MLIR] Lower LHLO CopyOp to linalg.generic.
PiperOrigin-RevId: 290583409 Change-Id: Ie525620f1e7c073c66573a529f38fae92a6d9fb6
This commit is contained in:
parent
46c271b15d
commit
eae45e5710
@ -102,6 +102,19 @@ func @exp(%input: memref<2x2xf32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @copy
|
||||
func @copy(%input: memref<2x4x8xf32>,
|
||||
%result: memref<2x4x8xf32>) {
|
||||
"xla_lhlo.copy"(%input, %result)
|
||||
: (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_cmp
|
||||
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||
%result: memref<2x2xi1>) {
|
||||
|
@ -56,13 +56,12 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
|
||||
}
|
||||
auto l = rewriter.create<LoadOp>(loc, lhs, induction_vars);
|
||||
auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars);
|
||||
Operation* result = MapLhloOpToStdScalarOp<LhloOp>(
|
||||
llvm::cast<LhloOp>(op), element_type, {l, r}, rewriter);
|
||||
if (result == nullptr) {
|
||||
Value opResult = MapLhloOpToStdScalarOp<LhloOp>(
|
||||
llvm::cast<LhloOp>(op), element_type, {l, r}, &rewriter);
|
||||
if (opResult == nullptr) {
|
||||
return this->matchFailure();
|
||||
}
|
||||
rewriter.create<StoreOp>(loc, result->getResult(0), op.out(),
|
||||
induction_vars);
|
||||
rewriter.create<StoreOp>(loc, opResult, op.out(), induction_vars);
|
||||
rewriter.eraseOp(op);
|
||||
return this->matchSuccess();
|
||||
}
|
||||
|
@ -106,9 +106,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<LhloOp> {
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointToEnd(block);
|
||||
Operation* op = MapLhloOpToStdScalarOp<LhloOp>(
|
||||
llvm::cast<LhloOp>(lhlo_op), bodyResultTypes, bodyArgs, rewriter);
|
||||
rewriter.create<linalg::YieldOp>(loc, op->getResults());
|
||||
Value opResult = MapLhloOpToStdScalarOp<LhloOp>(
|
||||
llvm::cast<LhloOp>(lhlo_op), bodyResultTypes, bodyArgs, &rewriter);
|
||||
rewriter.create<linalg::YieldOp>(loc, opResult);
|
||||
rewriter.eraseOp(lhlo_op);
|
||||
return ConversionPattern::matchSuccess();
|
||||
}
|
||||
@ -133,10 +133,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
||||
// Create two loads from the input.
|
||||
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
||||
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
||||
Operation* op = MapLhloOpToStdScalarOp<LhloOp>(
|
||||
Value opResult = MapLhloOpToStdScalarOp<LhloOp>(
|
||||
llvm::cast<LhloOp>(lhlo_op), argType.getElementType(),
|
||||
llvm::ArrayRef<Value>{lhs, rhs}, rewriter);
|
||||
rewriter.create<StoreOp>(loc, op->getResult(0), lhlo_op.out());
|
||||
llvm::ArrayRef<Value>{lhs, rhs}, &rewriter);
|
||||
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
|
||||
rewriter.eraseOp(lhlo_op);
|
||||
return ConversionPattern::matchSuccess();
|
||||
}
|
||||
@ -322,6 +322,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::CompareOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::CopyOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::DivOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::ExpOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::MaxOp>,
|
||||
|
@ -59,68 +59,68 @@ template <typename LHLO_BinaryOp>
|
||||
using ScalarIOp = typename ScalarOp<LHLO_BinaryOp>::IOp;
|
||||
|
||||
template <typename LhloOp>
|
||||
Operation* MapLhloOpToStdScalarOp(LhloOp lhlo_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> block_args, OpBuilder b) {
|
||||
Type element_type = block_args.front().getType();
|
||||
Value MapLhloOpToStdScalarOp(LhloOp lhlo_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
return b.template create<ScalarIOp<LhloOp>>(lhlo_op.getLoc(), result_types,
|
||||
block_args, mlir::None);
|
||||
return b->template create<ScalarIOp<LhloOp>>(lhlo_op.getLoc(), result_types,
|
||||
args, mlir::None);
|
||||
}
|
||||
if (element_type.isa<FloatType>()) {
|
||||
return b.template create<ScalarFOp<LhloOp>>(lhlo_op.getLoc(), result_types,
|
||||
block_args, mlir::None);
|
||||
return b->template create<ScalarFOp<LhloOp>>(lhlo_op.getLoc(), result_types,
|
||||
args, mlir::None);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
|
||||
xla_lhlo::MaxOp lhlo_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> block_args, OpBuilder b) {
|
||||
const auto& lhs = block_args[0];
|
||||
const auto& rhs = block_args[1];
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
|
||||
xla_lhlo::MaxOp lhlo_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
const auto& lhs = args[0];
|
||||
const auto& rhs = args[1];
|
||||
Type element_type = lhs.getType();
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
auto lhs_gt_rhs = b.create<ScalarIOp<CompareOp>>(
|
||||
auto lhs_gt_rhs = b->create<ScalarIOp<CompareOp>>(
|
||||
lhlo_op.getLoc(), CmpIPredicate::sgt, lhs, rhs);
|
||||
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_gt_rhs, lhs, rhs);
|
||||
return b->create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_gt_rhs, lhs, rhs);
|
||||
}
|
||||
if (element_type.isa<FloatType>()) {
|
||||
auto lhs_gt_rhs = b.create<ScalarFOp<CompareOp>>(
|
||||
auto lhs_gt_rhs = b->create<ScalarFOp<CompareOp>>(
|
||||
lhlo_op.getLoc(), CmpFPredicate::OGT, lhs, rhs);
|
||||
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_gt_rhs, lhs, rhs);
|
||||
return b->create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_gt_rhs, lhs, rhs);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
|
||||
xla_lhlo::MinOp lhlo_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> block_args, OpBuilder b) {
|
||||
const auto& lhs = block_args[0];
|
||||
const auto& rhs = block_args[1];
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
|
||||
xla_lhlo::MinOp lhlo_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
const auto& lhs = args[0];
|
||||
const auto& rhs = args[1];
|
||||
Type element_type = lhs.getType();
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
auto lhs_lt_rhs = b.create<ScalarIOp<CompareOp>>(
|
||||
auto lhs_lt_rhs = b->create<ScalarIOp<CompareOp>>(
|
||||
lhlo_op.getLoc(), CmpIPredicate::slt, lhs, rhs);
|
||||
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_lt_rhs, lhs, rhs);
|
||||
return b->create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_lt_rhs, lhs, rhs);
|
||||
}
|
||||
if (element_type.isa<FloatType>()) {
|
||||
auto lhs_lt_rhs = b.create<ScalarFOp<CompareOp>>(
|
||||
auto lhs_lt_rhs = b->create<ScalarFOp<CompareOp>>(
|
||||
lhlo_op.getLoc(), CmpFPredicate::OLT, lhs, rhs);
|
||||
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_lt_rhs, lhs, rhs);
|
||||
return b->create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_lt_rhs, lhs, rhs);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
|
||||
xla_lhlo::AndOp lhlo_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> block_args, OpBuilder b) {
|
||||
Type element_type = block_args.front().getType();
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
|
||||
xla_lhlo::AndOp lhlo_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
return element_type.isa<IntegerType>()
|
||||
? b.create<::mlir::AndOp>(lhlo_op.getLoc(), result_types,
|
||||
block_args, mlir::None)
|
||||
? b->create<::mlir::AndOp>(lhlo_op.getLoc(), result_types, args,
|
||||
mlir::None)
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
@ -148,21 +148,21 @@ inline Optional<CmpIPredicate> getIntCmpPredicate(
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
xla_lhlo::CompareOp lhlo_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> block_args, OpBuilder b) {
|
||||
const auto& lhs = block_args[0];
|
||||
const auto& rhs = block_args[1];
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
const auto& lhs = args[0];
|
||||
const auto& rhs = args[1];
|
||||
Type element_type = lhs.getType();
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
Optional<CmpIPredicate> predicate =
|
||||
getIntCmpPredicate(lhlo_op.comparison_direction());
|
||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||
return b.create<ScalarIOp<CompareOp>>(lhlo_op.getLoc(),
|
||||
predicate.getValue(), lhs, rhs);
|
||||
return b->create<ScalarIOp<CompareOp>>(lhlo_op.getLoc(),
|
||||
predicate.getValue(), lhs, rhs);
|
||||
}
|
||||
if (element_type.isa<FloatType>()) {
|
||||
return b.create<ScalarFOp<CompareOp>>(
|
||||
return b->create<ScalarFOp<CompareOp>>(
|
||||
lhlo_op.getLoc(), getFloatCmpPredicate(lhlo_op.comparison_direction()),
|
||||
lhs, rhs);
|
||||
}
|
||||
@ -170,24 +170,31 @@ inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
|
||||
xla_lhlo::SelectOp lhlo_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> block_args, OpBuilder b) {
|
||||
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), result_types, block_args,
|
||||
mlir::None);
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return b->create<::mlir::SelectOp>(lhlo_op.getLoc(), result_types, args,
|
||||
mlir::None);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
|
||||
xla_lhlo::ExpOp lhlo_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> block_args, OpBuilder b) {
|
||||
Type element_type = block_args.front().getType();
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
|
||||
xla_lhlo::ExpOp lhlo_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
return element_type.isa<FloatType>()
|
||||
? b.create<::mlir::ExpOp>(lhlo_op.getLoc(), result_types,
|
||||
block_args, mlir::None)
|
||||
? b->create<::mlir::ExpOp>(lhlo_op.getLoc(), result_types, args,
|
||||
mlir::None)
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
|
||||
xla_lhlo::CopyOp lhlo_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return args.front();
|
||||
}
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -88,15 +88,13 @@ TEST_F(LhloGenTest, Copy) {
|
||||
CompileAndVerifyIr(R"(
|
||||
HloModule Copy
|
||||
|
||||
ENTRY %Copy (x: f32[2,4,8]) -> f32[2,4,8] {
|
||||
%x = f32[2,4,8]{1,0,2} parameter(0)
|
||||
ROOT %copy = f32[2,4,8]{2,0,1} copy(f32[2,4,8]{1,0,2} %x)
|
||||
ENTRY %Copy (x: f32[2,4]) -> f32[2,4] {
|
||||
%x = f32[2,4] parameter(0)
|
||||
ROOT %copy = f32[2,4] copy(f32[2,4] %x)
|
||||
})",
|
||||
R"(
|
||||
;CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
|
||||
;CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
|
||||
;CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4x8xf32, #[[MAP0]]>, %[[RESULT:.*]]: memref<2x4x8xf32, #[[MAP1]]>) {
|
||||
;CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4x8xf32, #[[MAP0]]>, memref<2x4x8xf32, #[[MAP1]]>) -> ()
|
||||
;CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, %[[RESULT:.*]]: memref<2x4xf32>) {
|
||||
;CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> ()
|
||||
)");
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user