Fix convolution_2d_transpose_bias without bias

PiperOrigin-RevId: 308272528
Change-Id: I88547df361b1c0dac782ebdeb4c483b1e3ab8c27
This commit is contained in:
Logan Chien 2020-04-24 09:53:48 -07:00 committed by TensorFlower Gardener
parent 4ce3b6c63b
commit 3184d08434
2 changed files with 17 additions and 1 deletions

View File

@ -237,6 +237,14 @@ class TFL_TFTypesWithSameBits<int i, int j, int num> :
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">, Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<int n, int m> :
PredOpTrait<"operand " # n # " is at most " # m # "-D",
Or<[
CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
TFL_OperandIsUnrankedPred<n>,
CPred<"$_op.getOperand(" # n #
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
class TFL_OperandHasRankLessThanOrEqualTo<int n, int m> : class TFL_OperandHasRankLessThanOrEqualTo<int n, int m> :
PredOpTrait<"operand " # n # " is at most " # m # "-D", PredOpTrait<"operand " # n # " is at most " # m # "-D",
Or<[TFL_OperandIsUnrankedPred<n>, Or<[TFL_OperandIsUnrankedPred<n>,
@ -496,7 +504,7 @@ def TFL_Convolution2DTransposeBiasOp :
NoSideEffect, NoSideEffect,
TFL_OperandHasRank<0, 4>, TFL_OperandHasRank<0, 4>,
TFL_OperandHasRank<1, 4>, TFL_OperandHasRank<1, 4>,
TFL_OperandHasRankLessThanOrEqualTo<2, 1> TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<2, 1>
]> { ]> {
let summary = " Transpose convolution with bias operator"; let summary = " Transpose convolution with bias operator";

View File

@ -2046,6 +2046,14 @@ func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tenso
// ----- // -----
func @testConvolution2DTransposeNoBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> {
%cst = constant unit
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32>
return %0 : tensor<1x64x84x32xf32>
}
// -----
func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> { func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> {
%cst = constant unit %cst = constant unit
// expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}} // expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}}