From 454f89ab3baacbac567d6bcceef4c743f23ce58b Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Fri, 14 Feb 2020 14:50:45 -0800 Subject: [PATCH] Add conversions from XLA_HLO copy, compare and select to Linalg. Adapt the LHLO copy, compare and select conversions to linalg to support HLO conversion as well. PiperOrigin-RevId: 295232363 Change-Id: If7995ca3cf40d673d98efcbd409a16f71ac90a8d --- .../xla/tests/hlo-legalize-to-linalg.mlir | 66 ++++++++++++++++ .../xla/transforms/map_xla_to_scalar_op.h | 76 +++++++++++++++---- .../xla/transforms/xla_legalize_to_linalg.cc | 3 + 3 files changed, 129 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index a0a28dcf5af..b5242d06dae 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -17,6 +17,7 @@ func @float_add(%lhs: tensor<2x2xf32>, // ----- +// CHECK-LABEL: integer_add func @integer_add(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic @@ -28,6 +29,7 @@ func @integer_add(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: func @float_mul func @float_mul(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic @@ -39,6 +41,7 @@ func @float_mul(%lhs: tensor<2x2xf32>, // ----- +// CHECK-LABEL: func @integer_mul func @integer_mul(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic @@ -50,6 +53,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: func @float_remainder func @float_remainder(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic @@ -61,6 +65,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>, // ----- +// CHECK-LABEL: func @integer_remainder func @integer_remainder(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic @@ -72,6 +77,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: func @float_sub func @float_sub(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic @@ -83,6 +89,7 @@ func @float_sub(%lhs: tensor<2x2xf32>, // ----- +// CHECK-LABEL: func @integer_sub func @integer_sub(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic @@ -94,6 +101,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: func @float_abs func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: absf @@ -103,6 +111,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_exp func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: exp @@ -112,6 +121,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_ceil func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: ceilf @@ -121,6 +131,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_neg func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: negf @@ -130,6 +141,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_tanh func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic // CHECK: tanh @@ -139,6 +151,7 @@ func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @integer_and func @integer_and(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic @@ -147,3 +160,56 @@ func @integer_and(%lhs: tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } + +// ----- + +// CHECK-LABEL: func @float_cmp +func @float_cmp(%lhs: tensor<2x2xf32>, + %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) { + %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} + : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> + return %0 : tensor<2x2xi1> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = cmpf "oeq", %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + +// CHECK-LABEL: func @int_cmp +func @int_cmp(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> { + %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} + : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>) + return %0 : tensor<2x2xi1> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = cmpi "slt", %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + +// CHECK-LABEL: func @copy +func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { + %0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>) + return %0 : tensor<2x4x8xf32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : f32 + +// ----- + +// CHECK-LABEL: func @select +func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, + %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = "xla_hlo.select"(%pred, %lhs, %rhs) + : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) + return %0 : tensor<2x2xf32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h index 35e1be04fa1..b7b807333ba 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h @@ -44,6 +44,11 @@ struct ScalarOp { using IOp = ::mlir::CmpIOp; }; template <> +struct ScalarOp { + using FOp = ::mlir::CmpFOp; + using IOp = ::mlir::CmpIOp; +}; +template <> struct ScalarOp { using FOp = ::mlir::DivFOp; using IOp = ::mlir::SignedDivIOp; @@ -84,10 +89,10 @@ struct ScalarOp { using IOp = ::mlir::SubIOp; }; -template -using ScalarFOp = typename ScalarOp::FOp; -template -using ScalarIOp = typename ScalarOp::IOp; +template +using ScalarFOp = typename ScalarOp::FOp; +template +using ScalarIOp = typename ScalarOp::IOp; template struct MapXlaOpToStdScalarOpImpl { @@ -97,6 +102,14 @@ struct MapXlaOpToStdScalarOpImpl { } }; +template +struct MapXlaOpToStdScalarOpImpl { + Value operator()(Location loc, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + return b->template create(loc, result_types, args, mlir::None); + } +}; + template struct MapXlaOpToStdScalarOpImpl { Value operator()(Location loc, ArrayRef result_types, @@ -154,7 +167,8 @@ inline Value MapXlaOpToStdScalarOp(xla_hlo::AndOp xla_op, xla_op.getLoc(), result_types, args, b); } -inline CmpFPredicate getFloatCmpPredicate(StringRef xla_comparison_direction) { +inline Optional getFloatCmpPredicate( + StringRef xla_comparison_direction) { return llvm::StringSwitch(xla_comparison_direction) .Case("EQ", CmpFPredicate::OEQ) .Case("NE", CmpFPredicate::ONE) @@ -177,10 +191,10 @@ inline Optional getIntCmpPredicate( .Default(llvm::None); } -template <> -inline Value MapXlaOpToStdScalarOp( - xla_lhlo::CompareOp xla_op, ArrayRef result_types, - ArrayRef args, OpBuilder* b) { +template +inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op, + ArrayRef result_types, + ArrayRef args, OpBuilder* b) { const auto& lhs = args[0]; const auto& rhs = args[1]; Type element_type = lhs.getType(); @@ -188,16 +202,32 @@ inline Value MapXlaOpToStdScalarOp( Optional predicate = getIntCmpPredicate(xla_op.comparison_direction()); assert(predicate.hasValue() && "expected valid comparison direction"); - return b->create>( - xla_op.getLoc(), predicate.getValue(), lhs, rhs); + return b->create>(xla_op.getLoc(), + predicate.getValue(), lhs, rhs); } if (element_type.isa()) { - return b->create>( - xla_op.getLoc(), getFloatCmpPredicate(xla_op.comparison_direction()), - lhs, rhs); + Optional predicate = + getFloatCmpPredicate(xla_op.comparison_direction()); + assert(predicate.hasValue() && "expected valid comparison direction"); + return b->create>(xla_op.getLoc(), + predicate.getValue(), lhs, rhs); } return nullptr; } +template <> +inline Value MapXlaOpToStdScalarOp( + xla_lhlo::CompareOp xla_op, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + return MapXlaCompareOpToStdScalarOp(xla_op, result_types, + args, b); +} +template <> +inline Value MapXlaOpToStdScalarOp( + xla_hlo::CompareOp xla_op, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + return MapXlaCompareOpToStdScalarOp(xla_op, result_types, + args, b); +} template <> inline Value MapXlaOpToStdScalarOp( @@ -205,6 +235,13 @@ inline Value MapXlaOpToStdScalarOp( OpBuilder* b) { return args.front(); } +template <> +inline Value MapXlaOpToStdScalarOp(xla_hlo::CopyOp xla_op, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return args.front(); +} template <> inline Value MapXlaOpToStdScalarOp(xla_lhlo::ExpOp xla_op, @@ -364,8 +401,15 @@ template <> inline Value MapXlaOpToStdScalarOp( xla_lhlo::SelectOp xla_op, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return b->create<::mlir::SelectOp>(xla_op.getLoc(), result_types, args, - mlir::None); + return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(), + result_types, args, b); +} +template <> +inline Value MapXlaOpToStdScalarOp( + xla_hlo::SelectOp xla_op, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(), + result_types, args, b); } template <> diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 2d9b3b4f6ed..b6019b1e263 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -410,10 +410,13 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter>(context); }