Merge pull request #44097 from ahmedsabie:T2.7
PiperOrigin-RevId: 338413473 Change-Id: I59613a4f380073a69fa50344d3f9cb0184b78144
This commit is contained in:
commit
abfdb66c8e
@ -1361,7 +1361,8 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, %
|
||||
|
||||
// CHECK-LABEL: conv2d_backprop_input
|
||||
// CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32>
|
||||
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST]]) : (tensor<4xi32>) -> tensor<4xi32>
|
||||
// CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CAST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32>
|
||||
// CHECK: %[[CST_0:.*]] = constant unit
|
||||
// CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
|
||||
// CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2, %[[CST_0]]) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>, none) -> tensor<15x28x28x1xf32>
|
||||
@ -1587,10 +1588,17 @@ func @tranpose_int64_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
|
||||
// CHECK: "tfl.transpose"
|
||||
}
|
||||
|
||||
func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32> {
|
||||
func @tranpose_arg32(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32> {
|
||||
%0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32>
|
||||
return %0 : tensor<3x2xf32>
|
||||
// CHECK-LABEL: tranpose_arg
|
||||
// CHECK-LABEL: tranpose_arg32
|
||||
// CHECK: "tfl.transpose"
|
||||
}
|
||||
|
||||
func @tranpose_arg64(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi64>) -> tensor<3x2xf32> {
|
||||
%0 = "tf.Transpose"(%arg0, %arg1): (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32>
|
||||
return %0 : tensor<3x2xf32>
|
||||
// CHECK-LABEL: tranpose_arg64
|
||||
// CHECK: "tfl.transpose"
|
||||
}
|
||||
|
||||
|
@ -54,7 +54,7 @@ def ExtractSingleElementAsInt32 : NativeCodeCall<
|
||||
"$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast<ElementsAttr>()).getInt())">;
|
||||
|
||||
// Converts tensor with int64 to int32.
|
||||
def CreateCastToInt32 : NativeCodeCall<
|
||||
def CreateTFLCastToInt32Op : NativeCodeCall<
|
||||
"CreateCastToInt32($0, $_loc, $_builder)">;
|
||||
|
||||
// Checks whether the given operation has static shapes and same shapes of all inputs.
|
||||
@ -216,12 +216,9 @@ def LegalizeSqueeze : Pat<(TF_SqueezeOp $arg, $squeeze_dims),
|
||||
(TFL_SqueezeOp $arg, $squeeze_dims)>;
|
||||
def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>;
|
||||
|
||||
def LegalizeTransposeInt64 : Pat<
|
||||
(TF_TransposeOp $arg, (ConstantOp Int64ElementsAttr:$perm)),
|
||||
(TFL_TransposeOp $arg, (CreateCastToInt32 $perm))>;
|
||||
|
||||
def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm),
|
||||
(TFL_TransposeOp $arg, $perm)>;
|
||||
(TFL_TransposeOp $arg,
|
||||
(CreateTFLCastToInt32Op $perm))>;
|
||||
|
||||
def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>;
|
||||
def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>;
|
||||
|
@ -118,14 +118,11 @@ bool HasSameStaticShapes(Operation* op) {
|
||||
}
|
||||
|
||||
// Util that casts 'val' to Int32 by adding a cast Op.
|
||||
Value CreateCastToInt32(Attribute val, Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) {
|
||||
auto shape = val.getType().dyn_cast<RankedTensorType>().getShape();
|
||||
IntegerType new_ele_type = rewriter.getIntegerType(32);
|
||||
ShapedType new_type = RankedTensorType::get(shape, new_ele_type);
|
||||
return rewriter.create<TF::CastOp>(loc, new_type,
|
||||
rewriter.create<TF::ConstOp>(loc, val),
|
||||
rewriter.getBoolAttr(false));
|
||||
return rewriter.create<TFL::CastOp>(loc, new_type, val);
|
||||
}
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_legalize_tf.inc"
|
||||
|
Loading…
x
Reference in New Issue
Block a user