Lower tf.OnesLike using generic tf.BroadcastTo op
PiperOrigin-RevId: 339115562 Change-Id: Ia1151e74abe160904176058402da90852e6640ba
This commit is contained in:
parent
f3adb6a9a8
commit
422df815ba
@ -7708,6 +7708,20 @@ times by rerunning "MakeIterator".
|
||||
);
|
||||
}
|
||||
|
||||
def TF_OnesLikeOp : TF_Op<"OnesLike", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
let summary = "Returns a tensor of ones with the same shape and type as x.";
|
||||
|
||||
let arguments = (ins
|
||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$x
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Uint16, TF_Uint8]>:$y
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> {
|
||||
let summary = "Enqueue multiple Tensor values on the computation outfeed.";
|
||||
|
||||
|
@ -468,6 +468,16 @@ func @ZerosLike_variant(%arg0: tensor<!tf.variant<tensor<2xi32>>>) -> tensor<!tf
|
||||
return %0 : tensor<!tf.variant<tensor<2xi32>>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @OnesLike_unranked
|
||||
func @OnesLike_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> {
|
||||
// CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<*xi32>) -> tensor<?xi64>
|
||||
// CHECK: "tf.BroadcastTo"(%[[ONE]], %[[SHAPE]]) : (tensor<i32>, tensor<?xi64>) -> tensor<*xi32>
|
||||
|
||||
%0 = "tf.OnesLike"(%arg0) : (tensor<*xi32>) -> tensor<*xi32>
|
||||
return %0 : tensor<*xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @addN_2
|
||||
func @addN_2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1)
|
||||
|
@ -271,12 +271,15 @@ def LowerFakeQuantWithMinMaxArgs :
|
||||
def CreateTFShapeOp : NativeCodeCall<
|
||||
"$_builder.create<TF::ShapeOp>($0.getLoc(), $1, $2)">;
|
||||
|
||||
// TODO(hinsu): Support inputs of TensorList types.
|
||||
def LowerZerosLikeOp :
|
||||
Pat<(TF_ZerosLikeOp:$src_op
|
||||
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input),
|
||||
(TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)),
|
||||
(CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>;
|
||||
class LowerInitializationOp<Op FromOp, int initial_val>
|
||||
: Pat<(FromOp:$src_op
|
||||
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input),
|
||||
(TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<initial_val> $input)),
|
||||
(CreateTFShapeOp $src_op, $input,
|
||||
/*use 32bit*/ConstBoolAttrFalse))>;
|
||||
|
||||
def LowerZerosLikeOp : LowerInitializationOp<TF_ZerosLikeOp, 0>;
|
||||
def LowerOnesLikeOp : LowerInitializationOp<TF_OnesLikeOp, 1>;
|
||||
|
||||
def LowerScatterNdOp :
|
||||
Pat<(TF_ScatterNdOp $indices,
|
||||
|
Loading…
x
Reference in New Issue
Block a user