From 8a001a45216f02a7f760f1f667aabc45ef2b0a83 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Fri, 20 Dec 2019 08:21:28 -0800 Subject: [PATCH] Lower TensorFlow op Reciprocal to TensorFlow Div op tf.Reciprocal(x) = tf.Div(1, x) PiperOrigin-RevId: 286581834 Change-Id: I1e69f977e9dcceb599302c9b6a31665e4f75caff --- tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td | 2 ++ tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir | 7 +++++++ .../compiler/mlir/tensorflow/transforms/lower_tf.td | 8 ++++++++ 3 files changed, 17 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index c3a51613357..a301c976725 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -171,6 +171,8 @@ def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>; // Any integer or floating-point tensor types def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>; +def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>; + def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>; def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex], diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index 7f3a4c266cb..c1c5f419ca9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -424,3 +424,10 @@ func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> { %0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<1x2xf32> return %0 : tensor<1x2xf32> } + +func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor, tensor<*xf32>) -> tensor<*xf32> + %0 = "tf.Reciprocal"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index 5dc173a34b9..07792d57a6d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -183,6 +183,14 @@ def : Pat<(TF_PadOp TensorOf<[AnyInteger, AnyFloat]>:$input, $paddings), (TF_PadV2Op $input, $paddings, (TF_ConstOp (GetScalarOfType<0> $input)))>; +//===----------------------------------------------------------------------===// +// Reciprocal op patterns. +//===----------------------------------------------------------------------===// + +// TODO(hinsu): Support complex and unsigned input types. +def LowerReciprocal : Pat<(TF_ReciprocalOp TF_SintOrFpTensor:$x), + (TF_DivOp (TF_ConstOp (GetScalarOfType<1> $x)), $x)>; + //===----------------------------------------------------------------------===// // Rsqrt op patterns. //===----------------------------------------------------------------------===//