From 422df815bacd1d9c7f9edf0db905b87ffade39a7 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 26 Oct 2020 14:16:54 -0700 Subject: [PATCH] Lower tf.OnesLike using generic tf.BroadcastTo op PiperOrigin-RevId: 339115562 Change-Id: Ia1151e74abe160904176058402da90852e6640ba --- .../mlir/tensorflow/ir/tf_generated_ops.td | 14 ++++++++++++++ .../compiler/mlir/tensorflow/tests/lower_tf.mlir | 10 ++++++++++ .../mlir/tensorflow/transforms/lower_tf.td | 15 +++++++++------ 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 5fa1d2ecbfd..235e2672ccb 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -7708,6 +7708,20 @@ times by rerunning "MakeIterator". ); } +def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Returns a tensor of ones with the same shape and type as x."; + + let arguments = (ins + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$x + ); + + let results = (outs + TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> { let summary = "Enqueue multiple Tensor values on the computation outfeed."; diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index fcd2f2512fd..0ea1c671665 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -468,6 +468,16 @@ func @ZerosLike_variant(%arg0: tensor>>) -> tensor>> } +// CHECK-LABEL: func @OnesLike_unranked +func @OnesLike_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<*xi32>) -> tensor + // CHECK: "tf.BroadcastTo"(%[[ONE]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xi32> + + %0 = "tf.OnesLike"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + // CHECK-LABEL: func @addN_2 func @addN_2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index fec4c20e98d..97c51659100 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -271,12 +271,15 @@ def LowerFakeQuantWithMinMaxArgs : def CreateTFShapeOp : NativeCodeCall< "$_builder.create($0.getLoc(), $1, $2)">; -// TODO(hinsu): Support inputs of TensorList types. -def LowerZerosLikeOp : - Pat<(TF_ZerosLikeOp:$src_op - TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input), - (TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)), - (CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>; +class LowerInitializationOp + : Pat<(FromOp:$src_op + TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input), + (TF_BroadcastToOp (TF_ConstOp (GetScalarOfType $input)), + (CreateTFShapeOp $src_op, $input, + /*use 32bit*/ConstBoolAttrFalse))>; + +def LowerZerosLikeOp : LowerInitializationOp; +def LowerOnesLikeOp : LowerInitializationOp; def LowerScatterNdOp : Pat<(TF_ScatterNdOp $indices,