Fix convolution_2d_transpose_bias without bias
PiperOrigin-RevId: 308272528 Change-Id: I88547df361b1c0dac782ebdeb4c483b1e3ab8c27
This commit is contained in:
parent
4ce3b6c63b
commit
3184d08434
@ -237,6 +237,14 @@ class TFL_TFTypesWithSameBits<int i, int j, int num> :
|
|||||||
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
||||||
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
|
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
|
||||||
|
|
||||||
|
class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<int n, int m> :
|
||||||
|
PredOpTrait<"operand " # n # " is at most " # m # "-D",
|
||||||
|
Or<[
|
||||||
|
CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
|
||||||
|
TFL_OperandIsUnrankedPred<n>,
|
||||||
|
CPred<"$_op.getOperand(" # n #
|
||||||
|
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
|
||||||
|
|
||||||
class TFL_OperandHasRankLessThanOrEqualTo<int n, int m> :
|
class TFL_OperandHasRankLessThanOrEqualTo<int n, int m> :
|
||||||
PredOpTrait<"operand " # n # " is at most " # m # "-D",
|
PredOpTrait<"operand " # n # " is at most " # m # "-D",
|
||||||
Or<[TFL_OperandIsUnrankedPred<n>,
|
Or<[TFL_OperandIsUnrankedPred<n>,
|
||||||
@ -496,7 +504,7 @@ def TFL_Convolution2DTransposeBiasOp :
|
|||||||
NoSideEffect,
|
NoSideEffect,
|
||||||
TFL_OperandHasRank<0, 4>,
|
TFL_OperandHasRank<0, 4>,
|
||||||
TFL_OperandHasRank<1, 4>,
|
TFL_OperandHasRank<1, 4>,
|
||||||
TFL_OperandHasRankLessThanOrEqualTo<2, 1>
|
TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<2, 1>
|
||||||
]> {
|
]> {
|
||||||
let summary = " Transpose convolution with bias operator";
|
let summary = " Transpose convolution with bias operator";
|
||||||
|
|
||||||
|
@ -2046,6 +2046,14 @@ func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tenso
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @testConvolution2DTransposeNoBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> {
|
||||||
|
%cst = constant unit
|
||||||
|
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32>
|
||||||
|
return %0 : tensor<1x64x84x32xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> {
|
func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> {
|
||||||
%cst = constant unit
|
%cst = constant unit
|
||||||
// expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}}
|
// expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user