Add conversions from XLA_HLO copy, compare and select to Linalg.

Adapt the existing LHLO copy, compare, and select conversions to Linalg
so that they also handle the corresponding HLO ops.

PiperOrigin-RevId: 295232363
Change-Id: If7995ca3cf40d673d98efcbd409a16f71ac90a8d
Mahesh Ravishankar 2020-02-14 14:50:45 -08:00 committed by TensorFlower Gardener
parent 53e07edfca
commit 454f89ab3b
3 changed files with 129 additions and 16 deletions
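
For orientation: the header change below makes the scalar-op mapping dialect-agnostic. A ScalarOp<Op> trait names the float/integer standard op for a given XLA op, and the compare helper is templated over the op type, so one implementation serves both xla_lhlo.compare and xla_hlo.compare. The self-contained C++ sketch below shows the same trait-dispatch pattern with toy stand-ins (LhloCompareOp, HloCompareOp, string op names) in place of the real MLIR classes; it is an illustration only, not code from this change.

// Illustrative sketch with toy stand-ins; the real code maps
// xla_lhlo::CompareOp / xla_hlo::CompareOp to mlir::CmpFOp / mlir::CmpIOp.
#include <iostream>
#include <string>

// Toy stand-ins for the two dialects' compare ops.
struct LhloCompareOp {};
struct HloCompareOp {};

// Trait naming the scalar standard ops used in the lowering,
// mirroring ScalarOp<...>::FOp and ScalarOp<...>::IOp below.
template <typename XlaOp>
struct ScalarOp;

template <>
struct ScalarOp<LhloCompareOp> {
  static constexpr const char* FOp = "std.cmpf";
  static constexpr const char* IOp = "std.cmpi";
};

template <>
struct ScalarOp<HloCompareOp> {
  static constexpr const char* FOp = "std.cmpf";
  static constexpr const char* IOp = "std.cmpi";
};

// One helper serves both dialects, mirroring MapXlaCompareOpToStdScalarOp.
template <typename XlaCompareOpTy>
std::string MapCompareToStdScalarOp(bool is_float_element) {
  return is_float_element ? ScalarOp<XlaCompareOpTy>::FOp
                          : ScalarOp<XlaCompareOpTy>::IOp;
}

int main() {
  std::cout << MapCompareToStdScalarOp<LhloCompareOp>(true) << "\n";   // std.cmpf
  std::cout << MapCompareToStdScalarOp<HloCompareOp>(false) << "\n";   // std.cmpi
  return 0;
}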

@ -17,6 +17,7 @@ func @float_add(%lhs: tensor<2x2xf32>,
// -----
// CHECK-LABEL: integer_add
func @integer_add(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
@ -28,6 +29,7 @@ func @integer_add(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: func @float_mul
func @float_mul(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
@ -39,6 +41,7 @@ func @float_mul(%lhs: tensor<2x2xf32>,
// -----
// CHECK-LABEL: func @integer_mul
func @integer_mul(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
@ -50,6 +53,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: func @float_remainder
func @float_remainder(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
@ -61,6 +65,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>,
// -----
// CHECK-LABEL: func @integer_remainder
func @integer_remainder(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
@ -72,6 +77,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: func @float_sub
func @float_sub(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
@ -83,6 +89,7 @@ func @float_sub(%lhs: tensor<2x2xf32>,
// -----
// CHECK-LABEL: func @integer_sub
func @integer_sub(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
@ -94,6 +101,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: func @float_abs
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: absf
@ -103,6 +111,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @float_exp
func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: exp
@ -112,6 +121,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @float_ceil
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: ceilf
@ -121,6 +131,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @float_neg
func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: negf
@ -130,6 +141,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @float_tanh
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: tanh
@ -139,6 +151,7 @@ func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @integer_and
func @integer_and(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
@ -147,3 +160,56 @@ func @integer_and(%lhs: tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = cmpf "oeq", %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
: (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
return %0 : tensor<2x2xi1>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = cmpi "slt", %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @copy
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
%0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
return %0 : tensor<2x4x8xf32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : f32
// -----
// CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "xla_hlo.select"(%pred, %lhs, %rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
return %0 : tensor<2x2xf32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32

@ -44,6 +44,11 @@ struct ScalarOp<xla_lhlo::CompareOp> {
using IOp = ::mlir::CmpIOp;
};
template <>
struct ScalarOp<xla_hlo::CompareOp> {
using FOp = ::mlir::CmpFOp;
using IOp = ::mlir::CmpIOp;
};
template <>
struct ScalarOp<xla_lhlo::DivOp> {
using FOp = ::mlir::DivFOp;
using IOp = ::mlir::SignedDivIOp;
@ -84,10 +89,10 @@ struct ScalarOp<xla_hlo::SubOp> {
using IOp = ::mlir::SubIOp;
};
template <typename LHLO_BinaryOp>
using ScalarFOp = typename ScalarOp<LHLO_BinaryOp>::FOp;
template <typename LHLO_BinaryOp>
using ScalarIOp = typename ScalarOp<LHLO_BinaryOp>::IOp;
template <typename XLA_BinaryOp>
using ScalarFOp = typename ScalarOp<XLA_BinaryOp>::FOp;
template <typename XLA_BinaryOp>
using ScalarIOp = typename ScalarOp<XLA_BinaryOp>::IOp;
template <typename... Args>
struct MapXlaOpToStdScalarOpImpl {
@ -97,6 +102,14 @@ struct MapXlaOpToStdScalarOpImpl {
}
};
template <typename StdScalarOp>
struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
}
};
template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
Value operator()(Location loc, ArrayRef<Type> result_types,
@ -154,7 +167,8 @@ inline Value MapXlaOpToStdScalarOp<xla_hlo::AndOp>(xla_hlo::AndOp xla_op,
xla_op.getLoc(), result_types, args, b);
}
inline CmpFPredicate getFloatCmpPredicate(StringRef xla_comparison_direction) {
inline Optional<CmpFPredicate> getFloatCmpPredicate(
StringRef xla_comparison_direction) {
return llvm::StringSwitch<CmpFPredicate>(xla_comparison_direction)
.Case("EQ", CmpFPredicate::OEQ)
.Case("NE", CmpFPredicate::ONE)
@ -177,10 +191,10 @@ inline Optional<CmpIPredicate> getIntCmpPredicate(
.Default(llvm::None);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
xla_lhlo::CompareOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
template <typename XLACompareOpTy>
inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
const auto& lhs = args[0];
const auto& rhs = args[1];
Type element_type = lhs.getType();
@ -188,16 +202,32 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
Optional<CmpIPredicate> predicate =
getIntCmpPredicate(xla_op.comparison_direction());
assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarIOp<xla_lhlo::CompareOp>>(
xla_op.getLoc(), predicate.getValue(), lhs, rhs);
return b->create<ScalarIOp<XLACompareOpTy>>(xla_op.getLoc(),
predicate.getValue(), lhs, rhs);
}
if (element_type.isa<FloatType>()) {
return b->create<ScalarFOp<xla_lhlo::CompareOp>>(
xla_op.getLoc(), getFloatCmpPredicate(xla_op.comparison_direction()),
lhs, rhs);
Optional<CmpFPredicate> predicate =
getFloatCmpPredicate(xla_op.comparison_direction());
assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarFOp<XLACompareOpTy>>(xla_op.getLoc(),
predicate.getValue(), lhs, rhs);
}
return nullptr;
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
xla_lhlo::CompareOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(xla_op, result_types,
args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::CompareOp>(
xla_hlo::CompareOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapXlaCompareOpToStdScalarOp<xla_hlo::CompareOp>(xla_op, result_types,
args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CopyOp>(
@ -205,6 +235,13 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::CopyOp>(
OpBuilder* b) {
return args.front();
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::CopyOp>(xla_hlo::CopyOp xla_op,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return args.front();
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ExpOp>(xla_lhlo::ExpOp xla_op,
@ -364,8 +401,15 @@ template <>
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SelectOp>(
xla_lhlo::SelectOp xla_op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return b->create<::mlir::SelectOp>(xla_op.getLoc(), result_types, args,
mlir::None);
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
result_types, args, b);
}
template <>
inline Value MapXlaOpToStdScalarOp<xla_hlo::SelectOp>(
xla_hlo::SelectOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
result_types, args, b);
}
template <>
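
The MapXlaOpToStdScalarOpImpl overloads above dispatch on the element type through a recursive variadic template; this change adds the guard-free single-op case, which the select lowering uses. Below is a self-contained C++ sketch of that dispatch with toy stand-ins (tag structs and string op names) instead of the MLIR types; the names are illustrative only.

// Toy stand-ins; the real code tests mlir::IntegerType / mlir::FloatType and
// creates ops such as mlir::AddIOp / mlir::AddFOp.
#include <iostream>
#include <string>

struct IntegerType { static constexpr const char* kind = "int"; };
struct FloatType   { static constexpr const char* kind = "float"; };
struct AddIOp { static constexpr const char* name = "std.addi"; };
struct AddFOp { static constexpr const char* name = "std.addf"; };

// Primary template: type list exhausted, nothing matched.
template <typename... Args>
struct MapToStdScalarOpImpl {
  std::string operator()(const std::string&) const { return "<no match>"; }
};

// Single op, no type guard: always emitted (the shape of the
// MapXlaOpToStdScalarOpImpl<StdScalarOp> specialization added above).
template <typename StdScalarOp>
struct MapToStdScalarOpImpl<StdScalarOp> {
  std::string operator()(const std::string&) const { return StdScalarOp::name; }
};

// <SupportedType, StdScalarOp, Rest...>: emit StdScalarOp when the element
// kind matches SupportedType, otherwise recurse over the remaining pairs.
template <typename SupportedType, typename StdScalarOp, typename... Rest>
struct MapToStdScalarOpImpl<SupportedType, StdScalarOp, Rest...> {
  std::string operator()(const std::string& element_kind) const {
    if (element_kind == SupportedType::kind) return StdScalarOp::name;
    return MapToStdScalarOpImpl<Rest...>{}(element_kind);
  }
};

int main() {
  // Mirrors MapXlaOpToStdScalarOpImpl<IntegerType, AddIOp, FloatType, AddFOp>.
  MapToStdScalarOpImpl<IntegerType, AddIOp, FloatType, AddFOp> map_add;
  std::cout << map_add("int") << "\n";    // std.addi
  std::cout << map_add("float") << "\n";  // std.addf

  // The new single-op form: no guard, the op is always created.
  std::cout << MapToStdScalarOpImpl<AddFOp>{}("f32") << "\n";  // std.addf
  return 0;
}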

@ -410,10 +410,13 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>>(context);
}
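
The registrations above only build the pattern list; a pass then hands that list to MLIR's dialect-conversion driver. The sketch below shows roughly how that looked with early-2020 MLIR APIs, assuming it lives next to populateHLOToLinalgConversionPattern; the pass name, the second parameter type of that function, and the exact applyPartialConversion overload are assumptions for illustration, not taken from this commit.

// Rough sketch only: assumed early-2020 MLIR APIs; dialect headers for
// Linalg/StandardOps are omitted and exact overloads may differ.
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

struct HloToLinalgSketchPass
    : public mlir::FunctionPass<HloToLinalgSketchPass> {
  void runOnFunction() override {
    auto func = getFunction();

    // Build the pattern list via the populate function shown above
    // (its second argument type is assumed here).
    mlir::OwningRewritePatternList patterns;
    populateHLOToLinalgConversionPattern(func.getContext(), &patterns);

    // After conversion only Linalg and Standard ops should remain.
    mlir::ConversionTarget target(*func.getContext());
    target.addLegalDialect<mlir::linalg::LinalgDialect,
                           mlir::StandardOpsDialect>();

    if (mlir::failed(mlir::applyPartialConversion(func.getOperation(), target,
                                                  patterns)))
      signalPassFailure();
  }
};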