From 1f02731775ebd5b6dec2338e043007b928a8302d Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Tue, 5 May 2020 12:27:16 -0700 Subject: [PATCH] Add lowering from xla_hlo.sine to linalg. This also needs adding sine operation to LHLO. PiperOrigin-RevId: 309998062 Change-Id: Ib2b59a2925d6c1ed698da1cba95e7524441f5628 --- tensorflow/compiler/mlir/xla/ir/lhlo_ops.td | 2 ++ .../mlir/xla/tests/hlo-legalize-to-linalg.mlir | 10 ++++++++++ .../mlir/xla/tests/lhlo-legalize-to-linalg.mlir | 14 ++++++++++++++ .../mlir/xla/transforms/map_hlo_to_lhlo_op.h | 1 + .../mlir/xla/transforms/map_xla_to_scalar_op.h | 8 ++++++++ .../mlir/xla/transforms/xla_legalize_to_linalg.cc | 2 ++ 6 files changed, 37 insertions(+) diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index 6fcb2582002..3abd117f570 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -102,6 +102,8 @@ def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp; def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp; +def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp; + def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir index e31079a85cd..aa949a01388 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo-legalize-to-linalg.mlir @@ -222,6 +222,16 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_sin +func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: sin + %0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @copy // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir index 0cfb1582dcb..0fc30ed4901 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo-legalize-to-linalg.mlir @@ -423,6 +423,20 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @sin +func @sin(%input: memref<2x2xf32>, + %result: memref<2x2xf32>) { + "xla_lhlo.sine"(%input, %result) + : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + // CHECK-LABEL: func @negf func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () diff --git a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h index 6178434c8bb..fed21e9bafc 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h @@ -63,6 +63,7 @@ MAP_HLO_TO_LHLO(RemOp); MAP_HLO_TO_LHLO(RsqrtOp); MAP_HLO_TO_LHLO(SelectOp); MAP_HLO_TO_LHLO(SignOp); +MAP_HLO_TO_LHLO(SinOp); MAP_HLO_TO_LHLO(SqrtOp); MAP_HLO_TO_LHLO(SubOp); MAP_HLO_TO_LHLO(TanhOp); diff --git a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h index 8296011bf54..dceb73efb33 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_xla_to_scalar_op.h @@ -275,6 +275,14 @@ inline Value MapLhloOpToStdScalarOp( loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + /// Implements the conversion of XLA op to scalar op (to use within region of a /// linalg.generic op) for compare-select style operations like min/max. template diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index 48acd0b2d47..9cce6799288 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -633,6 +633,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -728,6 +729,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter,