[XLA:GPU][MLIR] Lower LHLO CopyOp to linalg.generic.

PiperOrigin-RevId: 290583409
Change-Id: Ie525620f1e7c073c66573a529f38fae92a6d9fb6
This commit is contained in:
Alexander Belyaev 2020-01-20 02:30:44 -08:00 committed by TensorFlower Gardener
parent 46c271b15d
commit eae45e5710
5 changed files with 84 additions and 66 deletions

View File

@ -102,6 +102,19 @@ func @exp(%input: memref<2x2xf32>,
// -----
// CHECK-LABEL: func @copy
func @copy(%input: memref<2x4x8xf32>,
%result: memref<2x4x8xf32>) {
"xla_lhlo.copy"(%input, %result)
: (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : f32
// -----
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xi1>) {

View File

@ -56,13 +56,12 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOp> {
}
auto l = rewriter.create<LoadOp>(loc, lhs, induction_vars);
auto r = rewriter.create<LoadOp>(loc, rhs, induction_vars);
Operation* result = MapLhloOpToStdScalarOp<LhloOp>(
llvm::cast<LhloOp>(op), element_type, {l, r}, rewriter);
if (result == nullptr) {
Value opResult = MapLhloOpToStdScalarOp<LhloOp>(
llvm::cast<LhloOp>(op), element_type, {l, r}, &rewriter);
if (opResult == nullptr) {
return this->matchFailure();
}
rewriter.create<StoreOp>(loc, result->getResult(0), op.out(),
induction_vars);
rewriter.create<StoreOp>(loc, opResult, op.out(), induction_vars);
rewriter.eraseOp(op);
return this->matchSuccess();
}

View File

@ -106,9 +106,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<LhloOp> {
}
rewriter.setInsertionPointToEnd(block);
Operation* op = MapLhloOpToStdScalarOp<LhloOp>(
llvm::cast<LhloOp>(lhlo_op), bodyResultTypes, bodyArgs, rewriter);
rewriter.create<linalg::YieldOp>(loc, op->getResults());
Value opResult = MapLhloOpToStdScalarOp<LhloOp>(
llvm::cast<LhloOp>(lhlo_op), bodyResultTypes, bodyArgs, &rewriter);
rewriter.create<linalg::YieldOp>(loc, opResult);
rewriter.eraseOp(lhlo_op);
return ConversionPattern::matchSuccess();
}
@ -133,10 +133,10 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
// Create two loads from the input.
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
Operation* op = MapLhloOpToStdScalarOp<LhloOp>(
Value opResult = MapLhloOpToStdScalarOp<LhloOp>(
llvm::cast<LhloOp>(lhlo_op), argType.getElementType(),
llvm::ArrayRef<Value>{lhs, rhs}, rewriter);
rewriter.create<StoreOp>(loc, op->getResult(0), lhlo_op.out());
llvm::ArrayRef<Value>{lhs, rhs}, &rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
rewriter.eraseOp(lhlo_op);
return ConversionPattern::matchSuccess();
}
@ -322,6 +322,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
PointwiseToLinalgConverter<xla_lhlo::CompareOp>,
PointwiseToLinalgConverter<xla_lhlo::CopyOp>,
PointwiseToLinalgConverter<xla_lhlo::DivOp>,
PointwiseToLinalgConverter<xla_lhlo::ExpOp>,
PointwiseToLinalgConverter<xla_lhlo::MaxOp>,

View File

@ -59,68 +59,68 @@ template <typename LHLO_BinaryOp>
using ScalarIOp = typename ScalarOp<LHLO_BinaryOp>::IOp;
template <typename LhloOp>
Operation* MapLhloOpToStdScalarOp(LhloOp lhlo_op, ArrayRef<Type> result_types,
ArrayRef<Value> block_args, OpBuilder b) {
Type element_type = block_args.front().getType();
Value MapLhloOpToStdScalarOp(LhloOp lhlo_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<IntegerType>()) {
return b.template create<ScalarIOp<LhloOp>>(lhlo_op.getLoc(), result_types,
block_args, mlir::None);
return b->template create<ScalarIOp<LhloOp>>(lhlo_op.getLoc(), result_types,
args, mlir::None);
}
if (element_type.isa<FloatType>()) {
return b.template create<ScalarFOp<LhloOp>>(lhlo_op.getLoc(), result_types,
block_args, mlir::None);
return b->template create<ScalarFOp<LhloOp>>(lhlo_op.getLoc(), result_types,
args, mlir::None);
}
return nullptr;
}
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
xla_lhlo::MaxOp lhlo_op, ArrayRef<Type> result_types,
ArrayRef<Value> block_args, OpBuilder b) {
const auto& lhs = block_args[0];
const auto& rhs = block_args[1];
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
xla_lhlo::MaxOp lhlo_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
const auto& lhs = args[0];
const auto& rhs = args[1];
Type element_type = lhs.getType();
if (element_type.isa<IntegerType>()) {
auto lhs_gt_rhs = b.create<ScalarIOp<CompareOp>>(
auto lhs_gt_rhs = b->create<ScalarIOp<CompareOp>>(
lhlo_op.getLoc(), CmpIPredicate::sgt, lhs, rhs);
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_gt_rhs, lhs, rhs);
return b->create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_gt_rhs, lhs, rhs);
}
if (element_type.isa<FloatType>()) {
auto lhs_gt_rhs = b.create<ScalarFOp<CompareOp>>(
auto lhs_gt_rhs = b->create<ScalarFOp<CompareOp>>(
lhlo_op.getLoc(), CmpFPredicate::OGT, lhs, rhs);
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_gt_rhs, lhs, rhs);
return b->create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_gt_rhs, lhs, rhs);
}
return nullptr;
}
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
xla_lhlo::MinOp lhlo_op, ArrayRef<Type> result_types,
ArrayRef<Value> block_args, OpBuilder b) {
const auto& lhs = block_args[0];
const auto& rhs = block_args[1];
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
xla_lhlo::MinOp lhlo_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
const auto& lhs = args[0];
const auto& rhs = args[1];
Type element_type = lhs.getType();
if (element_type.isa<IntegerType>()) {
auto lhs_lt_rhs = b.create<ScalarIOp<CompareOp>>(
auto lhs_lt_rhs = b->create<ScalarIOp<CompareOp>>(
lhlo_op.getLoc(), CmpIPredicate::slt, lhs, rhs);
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_lt_rhs, lhs, rhs);
return b->create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_lt_rhs, lhs, rhs);
}
if (element_type.isa<FloatType>()) {
auto lhs_lt_rhs = b.create<ScalarFOp<CompareOp>>(
auto lhs_lt_rhs = b->create<ScalarFOp<CompareOp>>(
lhlo_op.getLoc(), CmpFPredicate::OLT, lhs, rhs);
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_lt_rhs, lhs, rhs);
return b->create<::mlir::SelectOp>(lhlo_op.getLoc(), lhs_lt_rhs, lhs, rhs);
}
return nullptr;
}
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
xla_lhlo::AndOp lhlo_op, ArrayRef<Type> result_types,
ArrayRef<Value> block_args, OpBuilder b) {
Type element_type = block_args.front().getType();
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
xla_lhlo::AndOp lhlo_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
return element_type.isa<IntegerType>()
? b.create<::mlir::AndOp>(lhlo_op.getLoc(), result_types,
block_args, mlir::None)
? b->create<::mlir::AndOp>(lhlo_op.getLoc(), result_types, args,
mlir::None)
: nullptr;
}
@ -148,21 +148,21 @@ inline Optional<CmpIPredicate> getIntCmpPredicate(
}
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::CompareOp>(
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CompareOp>(
xla_lhlo::CompareOp lhlo_op, ArrayRef<Type> result_types,
ArrayRef<Value> block_args, OpBuilder b) {
const auto& lhs = block_args[0];
const auto& rhs = block_args[1];
ArrayRef<Value> args, OpBuilder* b) {
const auto& lhs = args[0];
const auto& rhs = args[1];
Type element_type = lhs.getType();
if (element_type.isa<IntegerType>()) {
Optional<CmpIPredicate> predicate =
getIntCmpPredicate(lhlo_op.comparison_direction());
assert(predicate.hasValue() && "expected valid comparison direction");
return b.create<ScalarIOp<CompareOp>>(lhlo_op.getLoc(),
predicate.getValue(), lhs, rhs);
return b->create<ScalarIOp<CompareOp>>(lhlo_op.getLoc(),
predicate.getValue(), lhs, rhs);
}
if (element_type.isa<FloatType>()) {
return b.create<ScalarFOp<CompareOp>>(
return b->create<ScalarFOp<CompareOp>>(
lhlo_op.getLoc(), getFloatCmpPredicate(lhlo_op.comparison_direction()),
lhs, rhs);
}
@ -170,24 +170,31 @@ inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::CompareOp>(
}
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
xla_lhlo::SelectOp lhlo_op, ArrayRef<Type> result_types,
ArrayRef<Value> block_args, OpBuilder b) {
return b.create<::mlir::SelectOp>(lhlo_op.getLoc(), result_types, block_args,
mlir::None);
ArrayRef<Value> args, OpBuilder* b) {
return b->create<::mlir::SelectOp>(lhlo_op.getLoc(), result_types, args,
mlir::None);
}
template <>
inline Operation* MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
xla_lhlo::ExpOp lhlo_op, ArrayRef<Type> result_types,
ArrayRef<Value> block_args, OpBuilder b) {
Type element_type = block_args.front().getType();
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
xla_lhlo::ExpOp lhlo_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
return element_type.isa<FloatType>()
? b.create<::mlir::ExpOp>(lhlo_op.getLoc(), result_types,
block_args, mlir::None)
? b->create<::mlir::ExpOp>(lhlo_op.getLoc(), result_types, args,
mlir::None)
: nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
xla_lhlo::CopyOp lhlo_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return args.front();
}
} // namespace xla_lhlo
} // namespace mlir

View File

@ -88,15 +88,13 @@ TEST_F(LhloGenTest, Copy) {
CompileAndVerifyIr(R"(
HloModule Copy
ENTRY %Copy (x: f32[2,4,8]) -> f32[2,4,8] {
%x = f32[2,4,8]{1,0,2} parameter(0)
ROOT %copy = f32[2,4,8]{2,0,1} copy(f32[2,4,8]{1,0,2} %x)
ENTRY %Copy (x: f32[2,4]) -> f32[2,4] {
%x = f32[2,4] parameter(0)
ROOT %copy = f32[2,4] copy(f32[2,4] %x)
})",
R"(
;CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
;CHECK: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
;CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4x8xf32, #[[MAP0]]>, %[[RESULT:.*]]: memref<2x4x8xf32, #[[MAP1]]>) {
;CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4x8xf32, #[[MAP0]]>, memref<2x4x8xf32, #[[MAP1]]>) -> ()
;CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, %[[RESULT:.*]]: memref<2x4xf32>) {
;CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> ()
)");
}