Merge pull request #46351 from WindQAQ:lower-size-to-prod-shape
PiperOrigin-RevId: 351443356
Change-Id: Ibadabcc9646e2eb1bf55e624470215c7bfd3e89a
commit b84e409c44
@@ -949,3 +949,23 @@ func @xlogy(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
   // CHECK: return %[[RESULT]]
   return %0 : tensor<*xf32>
 }
+
+// CHECK-LABEL: size_to_prod_shape_i32
+func @size_to_prod_shape_i32(%arg0 : tensor<1x?x2x3xf32>) -> tensor<i32> {
+  %0 = "tf.Size"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor<i32>
+  return %0 : tensor<i32>
+  // CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor<4xi32>
+  // CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi32>, tensor<i32>) -> tensor<i32>
+  // CHECK: return %[[PROD]]
+}
+
+// CHECK-LABEL: size_to_prod_shape_i64
+func @size_to_prod_shape_i64(%arg0 : tensor<1x?x2x3xf32>) -> tensor<i64> {
+  %0 = "tf.Size"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor<i64>
+  return %0 : tensor<i64>
+  // CHECK: %[[CONSTANT:.*]] = "tf.Const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+  // CHECK: %[[SHAPE:.*]] = "tf.Shape"(%arg0) : (tensor<1x?x2x3xf32>) -> tensor<4xi64>
+  // CHECK: %[[PROD:.*]] = "tf.Prod"(%[[SHAPE]], %[[CONSTANT]]) {keep_dims = false} : (tensor<4xi64>, tensor<i64>) -> tensor<i64>
+  // CHECK: return %[[PROD]]
+}
@@ -37,6 +37,12 @@ class GetI64ScalarElementsAttr<int value> :
 
 def TrueBoolAttr : AttrConstraint<CPred<"$_self.getValue()">>;
 
+def CreateTFShapeOp : NativeCodeCall<
+    "$_builder.create<TF::ShapeOp>($0.getLoc(), $1, $2)">;
+
+def IsI32 : NativeCodeCall<
+    "$_builder.getBoolAttr(getElementTypeOrSelf($0.getType()).isInteger(32))">;
+
 //===----------------------------------------------------------------------===//
 // BiasAddGrad op patterns.
 //===----------------------------------------------------------------------===//
@@ -253,6 +259,18 @@ def : Pat<(TF_RsqrtGradOp $lhs, $rhs),
                          (TF_ConstOp (GetScalarOfType<-2> $rhs))))>;
 
+
+//===----------------------------------------------------------------------===//
+// Size op patterns.
+//===----------------------------------------------------------------------===//
+
+// Size(x) = Prod(Shape(x), reduction_indices=0, keep_dims=false)
+def : Pat<(TF_SizeOp:$res $arg),
+          (TF_ProdOp
+             (CreateTFShapeOp $res, $arg, (IsI32 $res)),
+             /*reduction_indices=*/(TF_ConstOp (GetScalarOfType<0> $res)),
+             /*keep_dims=*/ConstBoolAttrFalse)>;
+
 
 //===----------------------------------------------------------------------===//
 // TanhGrad op patterns.
 //===----------------------------------------------------------------------===//
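(Note: the TableGen pattern above encodes the identity Size(x) = Prod(Shape(x), reduction_indices=0, keep_dims=false). As an informal sketch only, not part of this change, the same identity at the TensorFlow Python level, assuming a stock TensorFlow 2.x install:)

    import tensorflow as tf

    x = tf.zeros([1, 5, 2, 3])

    # tf.Size: the total number of elements in the tensor.
    size = tf.size(x, out_type=tf.int32)

    # The lowered form: Prod(Shape(x), reduction_indices=0, keep_dims=false).
    prod = tf.reduce_prod(tf.shape(x, out_type=tf.int32), axis=0, keepdims=False)

    assert int(size) == int(prod) == 30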
@@ -279,8 +297,6 @@ def LowerFakeQuantWithMinMaxArgs :
 // ZerosLike op patterns.
 //===----------------------------------------------------------------------===//
 
-def CreateTFShapeOp : NativeCodeCall<
-    "$_builder.create<TF::ShapeOp>($0.getLoc(), $1, $2)">;
 
 class LowerInitializationOp<Op FromOp, int initial_val>
   : Pat<(FromOp:$src_op
@@ -3816,76 +3816,6 @@ func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
   return %result : tensor<10xf32>
 }
 
-//===----------------------------------------------------------------------===//
-// tf.Size legalization
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: @size_scalar_i32
-func @size_scalar_i32(%input: tensor<f32>) -> (tensor<i32>) {
-  // CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
-  // CHECK-SAME: tensor<i32>
-  // CHECK: %[[CAST:.*]] = tensor.cast %[[CONST]] : tensor<i32> to tensor<i32>
-  %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<f32>) -> tensor<i32>
-  // CHECK: return %[[CAST]]
-  return %size : tensor<i32>
-}
-
-// CHECK-LABEL: @size_scalar_i64
-func @size_scalar_i64(%input: tensor<f32>) -> (tensor<i64>) {
-  // CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
-  // CHECK-SAME: tensor<i64>
-  // CHECK: %[[CAST:.*]] = tensor.cast %[[CONST]] : tensor<i64> to tensor<i64>
-  %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT64"} : (tensor<f32>) -> tensor<i64>
-  // CHECK: return %[[CAST]]
-  return %size : tensor<i64>
-}
-
-// CHECK-LABEL: @size_rank_one_i64
-// CHECK-SAME: (%[[INPUT:.*]]: tensor<?xf32>)
-func @size_rank_one_i64(%input: tensor<?xf32>) -> (tensor<i64>) {
-  // CHECK: %[[INIT:.*]] = mhlo.constant dense<1>
-  // CHECK-SAME: tensor<i64>
-
-  // CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
-  // CHECK-SAME: dimension = 0
-  // CHECK-SAME: tensor<i32>
-
-  // CHECK: %[[CAST_DIM_0:.*]] = "mhlo.convert"(%[[DIM_0]]) : (tensor<i32>) -> tensor<i64>
-  // CHECK: %[[RESULT:.*]] = chlo.broadcast_multiply %[[INIT]], %[[CAST_DIM_0]]
-
-  %size = "tf.Size"(%input) : (tensor<?xf32>) -> tensor<i64>
-  // CHECK: return %[[RESULT]]
-  return %size : tensor<i64>
-}
-
-// CHECK-LABEL: @size_ranked
-// CHECK-SAME: (%[[INPUT:.*]]: tensor<2x?x8xf32>)
-func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor<i32>) {
-  // CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
-  // CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
-  // CHECK-SAME: dimension = 0
-  // CHECK: %[[CAST_DIM_0:.*]] = "mhlo.convert"(%[[DIM_0]]) : (tensor<i32>) -> tensor<i32>
-  // CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[CAST_DIM_0]]
-  // CHECK: %[[DIM_1:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
-  // CHECK-SAME: dimension = 1
-  // CHECK: %[[CAST_DIM_1:.*]] = "mhlo.convert"(%[[DIM_1]]) : (tensor<i32>) -> tensor<i32>
-  // CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[CAST_DIM_1]]
-  // CHECK: %[[DIM_2:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
-  // CHECK-SAME: dimension = 2
-  // CHECK: %[[CAST_DIM_2:.*]] = "mhlo.convert"(%[[DIM_2]]) : (tensor<i32>) -> tensor<i32>
-  // CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[CAST_DIM_2]]
-  %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor<i32>
-  // CHECK: return %[[MUL_2]]
-  return %size : tensor<i32>
-}
-
-// CHECK-LABEL: @size_unranked
-func @size_unranked(%input: tensor<*xf32>) -> (tensor<i32>) {
-  // CHECK: tf.Size
-  %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<i32>
-  return %size : tensor<i32>
-}
-
 //===----------------------------------------------------------------------===//
 // tf.Split legalization
 //===----------------------------------------------------------------------===//
@@ -2751,59 +2751,6 @@ class ConvertSoftmaxOp : public OpRewritePattern<OpTy> {
   }
 };
 
-// Converts Size to HLO ops, computing the size of a ranked input tensor.
-// TODO(b/145253252): Update this to not require ranked input tensor shapes.
-//
-// The main logic of this pattern is to calculate the size by multiplying every
-// dimension of the input tensor's shape together.
-//
-// For example, the following source IR:
-//
-//   %size = "tf.Size"(%input) : (tensor<2x?x8xf32>) -> tensor<i32>
-//
-// will be converted into:
-//
-//   %const = mhlo.constant dense<1> : tensor<i32>
-//   %dim_0 = "mhlo.get_dimension_size"(%input) {dimension = 0 : i64} :
-//                                         (tensor<2x?x8xf32>) -> tensor<i32>
-//   %prod_0 = mhlo.multiply %const, %dim_0 : tensor<i32>
-//   %dim_1 = "mhlo.get_dimension_size"(%input) {dimension = 1 : i64} :
-//                                         (tensor<2x?x8xf32>) -> tensor<i32>
-//   %prod_1 = mhlo.multiply %prod_0, %dim_1 : tensor<i32>
-//   %dim_2 = "mhlo.get_dimension_size"(%input) {dimension = 2 : i64} :
-//                                         (tensor<2x?x8xf32>) -> tensor<i32>
-//   %size = mhlo.multiply %prod_1, %dim_2 : tensor<i32>
-class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
- public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TF::SizeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    auto input_ty = input.getType().dyn_cast<RankedTensorType>();
-    if (!input_ty) return failure();
-
-    const int64_t rank = input_ty.getRank();
-    auto result_ty = op.getResult().getType();
-    auto element_ty = result_ty.cast<TensorType>().getElementType();
-    Value size = GetScalarConstOfType(element_ty, op.getLoc(), 1, &rewriter);
-    for (int64_t i = 0; i < rank; ++i) {
-      auto i32_ty = rewriter.getIntegerType(32);
-      auto size_ty = RankedTensorType::get({}, i32_ty);
-      auto dim_index = rewriter.getI64IntegerAttr(i);
-      Value dim = rewriter.create<GetDimensionSizeOp>(op.getLoc(), size_ty,
-                                                      input, dim_index);
-      dim = rewriter.create<mhlo::ConvertOp>(op.getLoc(), result_ty, dim);
-      size = rewriter.create<chlo::BroadcastMulOp>(
-          op.getLoc(), size, dim,
-          /*broadcast_dimensions=*/DenseIntElementsAttr());
-    }
-    rewriter.replaceOp(op, size);
-
-    return success();
-  }
-};
-
 static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
                                            Value *out_lhs, Value *out_rhs,
                                            PatternRewriter *rewriter) {
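(Note: the removed ConvertSizeOp computed the element count by starting from a constant 1 and multiplying in each dimension size, as its comment above shows. A hypothetical plain-Python sketch of that loop structure, with `shape` standing in for the rewriter's per-dimension get_dimension_size queries:)

    def size_via_dim_products(shape):
        # Mirrors the removed ConvertSizeOp loop: one multiply per dimension,
        # analogous to one mhlo.get_dimension_size + chlo.broadcast_multiply.
        size = 1
        for dim in shape:
            size *= dim
        return size

    assert size_via_dim_products([2, 5, 8]) == 80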
@@ -6228,7 +6175,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
       ConvertMeanOp, ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp,
       ConvertProdOp, ConvertQrOp, ConvertDynamicRangeOp,
       ConvertMatrixDiagPartV3Op, ConvertRangeOp, ConvertSelectV2Op,
-      ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp,
+      ConvertSigmoidOp, ConvertShapeOp,
      ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
       ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
       ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,