Add conversions from XLA_HLO copy, compare and select to Linalg.
Adapt the LHLO copy, compare and select conversions to linalg to support HLO conversion as well. PiperOrigin-RevId: 295232363 Change-Id: If7995ca3cf40d673d98efcbd409a16f71ac90a8d
This commit is contained in:
parent
53e07edfca
commit
454f89ab3b
tensorflow/compiler/mlir/xla
@ -17,6 +17,7 @@ func @float_add(%lhs: tensor<2x2xf32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: integer_add
|
||||
func @integer_add(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
@ -28,6 +29,7 @@ func @integer_add(%lhs: tensor<2x2xi32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_mul
|
||||
func @float_mul(%lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
@ -39,6 +41,7 @@ func @float_mul(%lhs: tensor<2x2xf32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @integer_mul
|
||||
func @integer_mul(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
@ -50,6 +53,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_remainder
|
||||
func @float_remainder(%lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
@ -61,6 +65,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @integer_remainder
|
||||
func @integer_remainder(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
@ -72,6 +77,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_sub
|
||||
func @float_sub(%lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
@ -83,6 +89,7 @@ func @float_sub(%lhs: tensor<2x2xf32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @integer_sub
|
||||
func @integer_sub(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
@ -94,6 +101,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_abs
|
||||
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: absf
|
||||
@ -103,6 +111,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_exp
|
||||
func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: exp
|
||||
@ -112,6 +121,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_ceil
|
||||
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: ceilf
|
||||
@ -121,6 +131,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_neg
|
||||
func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: negf
|
||||
@ -130,6 +141,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_tanh
|
||||
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: tanh
|
||||
@ -139,6 +151,7 @@ func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @integer_and
|
||||
func @integer_and(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
@ -147,3 +160,56 @@ func @integer_and(%lhs: tensor<2x2xi32>,
|
||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_cmp
|
||||
func @float_cmp(%lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
|
||||
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
return %0 : tensor<2x2xi1>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = cmpf "oeq", %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @int_cmp
|
||||
func @int_cmp(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
|
||||
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
|
||||
: (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
|
||||
return %0 : tensor<2x2xi1>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = cmpi "slt", %[[LHS_IN]], %[[RHS_IN]] : i32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @copy
|
||||
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
|
||||
%0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
|
||||
return %0 : tensor<2x4x8xf32>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
|
||||
// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @select
|
||||
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%0 = "xla_hlo.select"(%pred, %lhs, %rhs)
|
||||
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
||||
|
@ -44,6 +44,11 @@ struct ScalarOp<xla_lhlo::CompareOp> {
|
||||
using IOp = ::mlir::CmpIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_hlo::CompareOp> {
|
||||
using FOp = ::mlir::CmpFOp;
|
||||
using IOp = ::mlir::CmpIOp;
|
||||
};
|
||||
template <>
|
||||
struct ScalarOp<xla_lhlo::DivOp> {
|
||||
using FOp = ::mlir::DivFOp;
|
||||
using IOp = ::mlir::SignedDivIOp;
|
||||
@ -84,10 +89,10 @@ struct ScalarOp<xla_hlo::SubOp> {
|
||||
using IOp = ::mlir::SubIOp;
|
||||
};
|
||||
|
||||
template <typename LHLO_BinaryOp>
|
||||
using ScalarFOp = typename ScalarOp<LHLO_BinaryOp>::FOp;
|
||||
template <typename LHLO_BinaryOp>
|
||||
using ScalarIOp = typename ScalarOp<LHLO_BinaryOp>::IOp;
|
||||
template <typename XLA_BinaryOp>
|
||||
using ScalarFOp = typename ScalarOp<XLA_BinaryOp>::FOp;
|
||||
template <typename XLA_BinaryOp>
|
||||
using ScalarIOp = typename ScalarOp<XLA_BinaryOp>::IOp;
|
||||
|
||||
template <typename... Args>
|
||||
struct MapXlaOpToStdScalarOpImpl {
|
||||
@ -97,6 +102,14 @@ struct MapXlaOpToStdScalarOpImpl {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename StdScalarOp>
|
||||
struct MapXlaOpToStdScalarOpImpl<StdScalarOp> {
|
||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SupportedType, typename StdScalarOp, typename... Args>
|
||||
struct MapXlaOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||
@ -154,7 +167,8 @@ inline Value MapXlaOpToStdScalarOp<xla_hlo::AndOp>(xla_hlo::AndOp xla_op,
|
||||
xla_op.getLoc(), result_types, args, b);
|
||||
}
|
||||
|
||||
inline CmpFPredicate getFloatCmpPredicate(StringRef xla_comparison_direction) {
|
||||
inline Optional<CmpFPredicate> getFloatCmpPredicate(
|
||||
StringRef xla_comparison_direction) {
|
||||
return llvm::StringSwitch<CmpFPredicate>(xla_comparison_direction)
|
||||
.Case("EQ", CmpFPredicate::OEQ)
|
||||
.Case("NE", CmpFPredicate::ONE)
|
||||
@ -177,10 +191,10 @@ inline Optional<CmpIPredicate> getIntCmpPredicate(
|
||||
.Default(llvm::None);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
xla_lhlo::CompareOp xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
template <typename XLACompareOpTy>
|
||||
inline Value MapXlaCompareOpToStdScalarOp(XLACompareOpTy xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
const auto& lhs = args[0];
|
||||
const auto& rhs = args[1];
|
||||
Type element_type = lhs.getType();
|
||||
@ -188,16 +202,32 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
Optional<CmpIPredicate> predicate =
|
||||
getIntCmpPredicate(xla_op.comparison_direction());
|
||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||
return b->create<ScalarIOp<xla_lhlo::CompareOp>>(
|
||||
xla_op.getLoc(), predicate.getValue(), lhs, rhs);
|
||||
return b->create<ScalarIOp<XLACompareOpTy>>(xla_op.getLoc(),
|
||||
predicate.getValue(), lhs, rhs);
|
||||
}
|
||||
if (element_type.isa<FloatType>()) {
|
||||
return b->create<ScalarFOp<xla_lhlo::CompareOp>>(
|
||||
xla_op.getLoc(), getFloatCmpPredicate(xla_op.comparison_direction()),
|
||||
lhs, rhs);
|
||||
Optional<CmpFPredicate> predicate =
|
||||
getFloatCmpPredicate(xla_op.comparison_direction());
|
||||
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||
return b->create<ScalarFOp<XLACompareOpTy>>(xla_op.getLoc(),
|
||||
predicate.getValue(), lhs, rhs);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||
xla_lhlo::CompareOp xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(xla_op, result_types,
|
||||
args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CompareOp>(
|
||||
xla_hlo::CompareOp xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return MapXlaCompareOpToStdScalarOp<xla_hlo::CompareOp>(xla_op, result_types,
|
||||
args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::CopyOp>(
|
||||
@ -205,6 +235,13 @@ inline Value MapXlaOpToStdScalarOp<xla_lhlo::CopyOp>(
|
||||
OpBuilder* b) {
|
||||
return args.front();
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::CopyOp>(xla_hlo::CopyOp xla_op,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return args.front();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::ExpOp>(xla_lhlo::ExpOp xla_op,
|
||||
@ -364,8 +401,15 @@ template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_lhlo::SelectOp>(
|
||||
xla_lhlo::SelectOp xla_op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args, OpBuilder* b) {
|
||||
return b->create<::mlir::SelectOp>(xla_op.getLoc(), result_types, args,
|
||||
mlir::None);
|
||||
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
|
||||
result_types, args, b);
|
||||
}
|
||||
template <>
|
||||
inline Value MapXlaOpToStdScalarOp<xla_hlo::SelectOp>(
|
||||
xla_hlo::SelectOp xla_op, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapXlaOpToStdScalarOpImpl<::mlir::SelectOp>{}(xla_op.getLoc(),
|
||||
result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -410,10 +410,13 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>>(context);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user