Add lowering from xla_hlo.sine to linalg.
This also needs adding sine operation to LHLO. PiperOrigin-RevId: 309998062 Change-Id: Ib2b59a2925d6c1ed698da1cba95e7524441f5628
This commit is contained in:
parent
9dce54f00b
commit
1f02731775
@ -102,6 +102,8 @@ def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp;
|
|||||||
|
|
||||||
def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;
|
def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;
|
||||||
|
|
||||||
|
def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp;
|
||||||
|
|
||||||
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp;
|
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -222,6 +222,16 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @float_sin
|
||||||
|
func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: sin
|
||||||
|
%0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||||
|
return %0 : tensor<2x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @copy
|
// CHECK-LABEL: func @copy
|
||||||
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
|
||||||
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
|
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
|
||||||
|
@ -423,6 +423,20 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @sin
|
||||||
|
func @sin(%input: memref<2x2xf32>,
|
||||||
|
%result: memref<2x2xf32>) {
|
||||||
|
"xla_lhlo.sine"(%input, %result)
|
||||||
|
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @negf
|
// CHECK-LABEL: func @negf
|
||||||
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
"xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
"xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
|
@ -63,6 +63,7 @@ MAP_HLO_TO_LHLO(RemOp);
|
|||||||
MAP_HLO_TO_LHLO(RsqrtOp);
|
MAP_HLO_TO_LHLO(RsqrtOp);
|
||||||
MAP_HLO_TO_LHLO(SelectOp);
|
MAP_HLO_TO_LHLO(SelectOp);
|
||||||
MAP_HLO_TO_LHLO(SignOp);
|
MAP_HLO_TO_LHLO(SignOp);
|
||||||
|
MAP_HLO_TO_LHLO(SinOp);
|
||||||
MAP_HLO_TO_LHLO(SqrtOp);
|
MAP_HLO_TO_LHLO(SqrtOp);
|
||||||
MAP_HLO_TO_LHLO(SubOp);
|
MAP_HLO_TO_LHLO(SubOp);
|
||||||
MAP_HLO_TO_LHLO(TanhOp);
|
MAP_HLO_TO_LHLO(TanhOp);
|
||||||
|
@ -275,6 +275,14 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
|
|||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
/// Implements the conversion of XLA op to scalar op (to use within region of a
|
/// Implements the conversion of XLA op to scalar op (to use within region of a
|
||||||
/// linalg.generic op) for compare-select style operations like min/max.
|
/// linalg.generic op) for compare-select style operations like min/max.
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
|
@ -633,6 +633,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||||||
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
|
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
|
||||||
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
|
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
|
||||||
PointwiseToLinalgConverter<xla_lhlo::SignOp>,
|
PointwiseToLinalgConverter<xla_lhlo::SignOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::SinOp>,
|
||||||
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
|
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
|
||||||
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
|
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
|
||||||
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
|
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
|
||||||
@ -728,6 +729,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||||||
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
|
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
|
||||||
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
|
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
|
||||||
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
|
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
|
||||||
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
|
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
|
||||||
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
|
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
|
||||||
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
|
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user