Add lowering from xla_hlo.sine to linalg.
This also needs adding sine operation to LHLO. PiperOrigin-RevId: 309998062 Change-Id: Ib2b59a2925d6c1ed698da1cba95e7524441f5628
This commit is contained in:
parent
9dce54f00b
commit
1f02731775
@ -102,6 +102,8 @@ def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp;
|
||||
|
||||
def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;
|
||||
|
||||
def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp;
|
||||
|
||||
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -222,6 +222,16 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_sin
|
||||
func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: sin
|
||||
%0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @copy
|
||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
|
||||
|
@ -423,6 +423,20 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sin
|
||||
func @sin(%input: memref<2x2xf32>,
|
||||
%result: memref<2x2xf32>) {
|
||||
"xla_lhlo.sine"(%input, %result)
|
||||
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @negf
|
||||
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
|
@ -63,6 +63,7 @@ MAP_HLO_TO_LHLO(RemOp);
|
||||
MAP_HLO_TO_LHLO(RsqrtOp);
|
||||
MAP_HLO_TO_LHLO(SelectOp);
|
||||
MAP_HLO_TO_LHLO(SignOp);
|
||||
MAP_HLO_TO_LHLO(SinOp);
|
||||
MAP_HLO_TO_LHLO(SqrtOp);
|
||||
MAP_HLO_TO_LHLO(SubOp);
|
||||
MAP_HLO_TO_LHLO(TanhOp);
|
||||
|
@ -275,6 +275,14 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
/// Implements the conversion of XLA op to scalar op (to use within region of a
|
||||
/// linalg.generic op) for compare-select style operations like min/max.
|
||||
template <typename... Args>
|
||||
|
@ -633,6 +633,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SignOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SinOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
|
||||
@ -728,6 +729,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
|
||||
|
Loading…
Reference in New Issue
Block a user