From 63926472df4f777b43146c608a0027a42569fe57 Mon Sep 17 00:00:00 2001 From: Jing Pu Date: Mon, 18 May 2020 19:34:59 -0700 Subject: [PATCH] Fix TF_ConcatV2Op conversion pattern when the axis is a I64 Tensor. PiperOrigin-RevId: 312201848 Change-Id: I55fcd3b514e9da905d0687d7c66e4da49c178ea5 --- .../compiler/mlir/lite/tests/legalize-tf.mlir | 9 ++++++ .../mlir/lite/transforms/legalize_tf.cc | 29 +++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 15b6bf56b7a..15c73d2db2c 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2 // CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> } +func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { + %0 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor) -> tensor<2x3xi32> + return %1 : tensor<2x3xi32> + +// CHECK-LABEL: concatv2I64Axis +// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32> +} + func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor { %0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index ab4c4f5c4cf..bfcbc190638 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" @@ -202,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite( return success(); } +// Converts any IntegerAttr to an IntegerAttr of an i32 type. +// The value won't change in the new attribute, but if the value is out of +// the bound of i32, the function returns a failure. +LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) { + if (attr.getType().isInteger(/*width=*/32)) { + *attr_i32 = attr; + return success(); + } + + int64_t value = attr.getInt(); + if (value > std::numeric_limits::max() || + value < std::numeric_limits::min()) { + return failure(); + } + + *attr_i32 = IntegerAttr::get( + IntegerType::get(/*width=*/32, attr.getContext()), value); + return success(); +} + LogicalResult ConvertTFConcatV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_concat_op = cast(op); @@ -211,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite( // Extract axis attribute from constant axis tensor ElementsAttr axis; if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure(); + IntegerAttr axis_int = ExtractSingleElementAsInteger(axis); + + // "axis" operand could be a i64 tensor. Resolve it here. + IntegerAttr axis_i32; + if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure(); StringAttr fused_activation_function = StringAttr::get("NONE", rewriter.getContext()); rewriter.replaceOpWithNewOp( - op, output_type, values, ExtractSingleElementAsInteger(axis), - fused_activation_function); + op, output_type, values, axis_i32, fused_activation_function); return success(); }