From 3184d08434e7539940d989262a9120a228a8ac3c Mon Sep 17 00:00:00 2001 From: Logan Chien Date: Fri, 24 Apr 2020 09:53:48 -0700 Subject: [PATCH] Fix convolution_2d_transpose_bias without bias PiperOrigin-RevId: 308272528 Change-Id: I88547df361b1c0dac782ebdeb4c483b1e3ab8c27 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 10 +++++++++- tensorflow/compiler/mlir/lite/tests/ops.mlir | 8 ++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 6cb8011a8d6..9031d54070c 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -237,6 +237,14 @@ class TFL_TFTypesWithSameBits : Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; +class TFL_OperandIsNoneOrHasRankLessThanOrEqualTo : + PredOpTrait<"operand " # n # " is at most " # m # "-D", + Or<[ + CPred<"$_op.getOperand(" # n # ").getType().isa()">, + TFL_OperandIsUnrankedPred, + CPred<"$_op.getOperand(" # n # + ").getType().cast().getRank() <= " # m>]>>; + class TFL_OperandHasRankLessThanOrEqualTo : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[TFL_OperandIsUnrankedPred, @@ -496,7 +504,7 @@ def TFL_Convolution2DTransposeBiasOp : NoSideEffect, TFL_OperandHasRank<0, 4>, TFL_OperandHasRank<1, 4>, - TFL_OperandHasRankLessThanOrEqualTo<2, 1> + TFL_OperandIsNoneOrHasRankLessThanOrEqualTo<2, 1> ]> { let summary = " Transpose convolution with bias operator"; diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 697e93e582c..052e2cab64d 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -2046,6 +2046,14 @@ func @testConvolution2DTransposeBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tenso // ----- +func @testConvolution2DTransposeNoBias(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>) -> tensor<1x64x84x32xf32> { + %cst = constant unit + %0 = "tfl.convolution_2d_transpose_bias"(%arg0, %arg1, %cst) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, none) -> tensor<1x64x84x32xf32> + return %0 : tensor<1x64x84x32xf32> +} + +// ----- + func @testTransposeConvBadOutputRank(%arg0: tensor<4xi32>, %arg1: tensor<32x4x4x128xf32>, %arg2: tensor<1x32x42x128xf32>) -> tensor<64x84x32xf32> { %cst = constant unit // expected-error @+1 {{expect output type has rank = 4, got output type tensor<64x84x32xf32>}}