Add lowering from xla_hlo.sine to linalg.

This also needs adding sine operation to LHLO.

PiperOrigin-RevId: 309998062
Change-Id: Ib2b59a2925d6c1ed698da1cba95e7524441f5628
This commit is contained in:
Mahesh Ravishankar 2020-05-05 12:27:16 -07:00 committed by TensorFlower Gardener
parent 9dce54f00b
commit 1f02731775
6 changed files with 37 additions and 0 deletions

View File

@ -102,6 +102,8 @@ def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp;
def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;
def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp;
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp;
//===----------------------------------------------------------------------===//

View File

@ -222,6 +222,16 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @float_sin
func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: sin
%0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {

View File

@ -423,6 +423,20 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @negf
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()

View File

@ -63,6 +63,7 @@ MAP_HLO_TO_LHLO(RemOp);
MAP_HLO_TO_LHLO(RsqrtOp);
MAP_HLO_TO_LHLO(SelectOp);
MAP_HLO_TO_LHLO(SignOp);
MAP_HLO_TO_LHLO(SinOp);
MAP_HLO_TO_LHLO(SqrtOp);
MAP_HLO_TO_LHLO(SubOp);
MAP_HLO_TO_LHLO(TanhOp);

View File

@ -275,6 +275,14 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
loc, result_types, args, b);
}
/// Implements the conversion of XLA op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args>

View File

@ -633,6 +633,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
PointwiseToLinalgConverter<xla_lhlo::SignOp>,
PointwiseToLinalgConverter<xla_lhlo::SinOp>,
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
@ -728,6 +729,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,