Lower tf.OnesLike using generic tf.BroadcastTo op

PiperOrigin-RevId: 339115562
Change-Id: Ia1151e74abe160904176058402da90852e6640ba
This commit is contained in:
Smit Hinsu 2020-10-26 14:16:54 -07:00 committed by TensorFlower Gardener
parent f3adb6a9a8
commit 422df815ba
3 changed files with 33 additions and 6 deletions

View File

@ -7708,6 +7708,20 @@ times by rerunning "MakeIterator".
);
}
def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Returns a tensor of ones with the same shape and type as x.";
let arguments = (ins
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$x
);
let results = (outs
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$y
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> {
let summary = "Enqueue multiple Tensor values on the computation outfeed.";

View File

@ -468,6 +468,16 @@ func @ZerosLike_variant(%arg0: tensor<!tf.variant<tensor<2xi32>>>) -> tensor<!tf
return %0 : tensor<!tf.variant<tensor<2xi32>>>
}
// CHECK-LABEL: func @OnesLike_unranked
func @OnesLike_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> {
// CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<*xi32>) -> tensor<?xi64>
// CHECK: "tf.BroadcastTo"(%[[ONE]], %[[SHAPE]]) : (tensor<i32>, tensor<?xi64>) -> tensor<*xi32>
%0 = "tf.OnesLike"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
}
// CHECK-LABEL: func @addN_2
func @addN_2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)

View File

@ -271,12 +271,15 @@ def LowerFakeQuantWithMinMaxArgs :
def CreateTFShapeOp : NativeCodeCall<
"$_builder.create<TF::ShapeOp>($0.getLoc(), $1, $2)">;
// TODO(hinsu): Support inputs of TensorList types.
def LowerZerosLikeOp :
Pat<(TF_ZerosLikeOp:$src_op
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input),
(TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)),
(CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>;
class LowerInitializationOp<Op FromOp, int initial_val>
: Pat<(FromOp:$src_op
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input),
(TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<initial_val> $input)),
(CreateTFShapeOp $src_op, $input,
/*use 32bit*/ConstBoolAttrFalse))>;
def LowerZerosLikeOp : LowerInitializationOp<TF_ZerosLikeOp, 0>;
def LowerOnesLikeOp : LowerInitializationOp<TF_OnesLikeOp, 1>;
def LowerScatterNdOp :
Pat<(TF_ScatterNdOp $indices,