Lower TensorFlow op Reciprocal to TensorFlow Div op

tf.Reciprocal(x) = tf.Div(1, x)

PiperOrigin-RevId: 286581834
Change-Id: I1e69f977e9dcceb599302c9b6a31665e4f75caff
This commit is contained in:
Smit Hinsu 2019-12-20 08:21:28 -08:00 committed by TensorFlower Gardener
parent 9b33150707
commit 8a001a4521
3 changed files with 17 additions and 0 deletions

View File

@ -171,6 +171,8 @@ def TF_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TF_I32Or64]>;
// Any integer or floating-point tensor types
def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>;
def TF_SintOrFpTensor : TensorOf<[TF_SInt, AnyFloat]>;
def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>;
def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex],

View File

@ -424,3 +424,10 @@ func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> {
%0 = "tf.DynamicStitch"(%indices, %arg0) : (tensor<2xi32>, tensor<2x2xf32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
%0 = "tf.Reciprocal"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

View File

@ -183,6 +183,14 @@ def : Pat<(TF_PadOp TensorOf<[AnyInteger, AnyFloat]>:$input, $paddings),
(TF_PadV2Op $input, $paddings,
(TF_ConstOp (GetScalarOfType<0> $input)))>;
//===----------------------------------------------------------------------===//
// Reciprocal op patterns.
//===----------------------------------------------------------------------===//
// TODO(hinsu): Support complex and unsigned input types.
def LowerReciprocal : Pat<(TF_ReciprocalOp TF_SintOrFpTensor:$x),
(TF_DivOp (TF_ConstOp (GetScalarOfType<1> $x)), $x)>;
//===----------------------------------------------------------------------===//
// Rsqrt op patterns.
//===----------------------------------------------------------------------===//