diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 1b7d9c47c4c..3d19ee1a3da 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -537,6 +537,27 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } +// CHECK-LABEL: func @rsqrt +func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: func @rsqrt_dynamic +func @rsqrt_dynamic(%arg0: tensor) -> tensor { + // CHECK: "tf.Rsqrt"(%arg0) : (tensor) -> tensor + %0 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @rsqrt_rankless +func @rsqrt_rankless(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "tf.Rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: func @tanh func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index ade62bf84cc..2445d3a2eba 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -191,9 +191,10 @@ foreach Mapping = [ [TF_FloorOp, HLO_FloorOp], [TF_LogOp, HLO_LogOp], [TF_NegOp, HLO_NegOp], + [TF_RsqrtOp, HLO_RsqrtOp], [TF_TanhOp, HLO_TanhOp], ] in { - def : Pat<(Mapping[0] AnyTensor:$input), + def : Pat<(Mapping[0] HLO_Tensor:$input), (Mapping[1] $input)>; }