Fix convolution_2d_transpose_bias without bias
PiperOrigin-RevId: 308272528 Change-Id: I88547df361b1c0dac782ebdeb4c483b1e3ab8c27
This commit is contained in:
parent
4ce3b6c63b
commit
3184d08434
@ -237,6 +237,14 @@ class TFL_TFTypesWithSameBits<int i, int j, int num> :
|
||||
Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa<mlir::TF::Quint" # num # "Type>()">,
|
||||
CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>;
|
||||
|
||||
class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is at most " # m # "-D",
|
||||
Or<[
|
||||
CPred<"$_op.getOperand(" # n # ").getType().isa<NoneType>()">,
|
||||
TFL_OperandIsUnrankedPred<n>,
|
||||
CPred<"$_op.getOperand(" # n #
|
||||
").getType().cast<ShapedType>().getRank() <= " # m>]>>;
|
||||
|
||||
class TFL_OperandHasRankLessThanOrEqualTo<int n, int m> :
|
||||
PredOpTrait<"operand " # n # " is at most " # m # "-D",
|
||||
Or<[TFL_OperandIsUnrankedPred<n>,
|
||||
@ -496,7 +504,7 @@ def TFL_Convolution2DTransposeBiasOp :
|
||||
NoSideEffect,
|
||||
TFL_OperandHasRank<0, 4>,
|
||||
TFL_OperandHasRank<1, 4>,
|
||||
TFL_OperandHasRankLessThanOrEqualTo<2, 1>
|
||||
TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<2, 1>
|
||||
]> {
|
||||
let summary = " Transpose convolution with bias operator";
|
||||
|
||||
|
@ -2046,6 +2046,14 @@ func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tenso
|
||||
|
||||
// -----
|
||||
|
||||
func @testConvolution2DTransposeNoBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> {
|
||||
%cst = constant unit
|
||||
%0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32>
|
||||
return %0 : tensor<1x64x84x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> {
|
||||
%cst = constant unit
|
||||
// expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}}
|
||||
|
Loading…
x
Reference in New Issue
Block a user