Fix TF_ConcatV2Op conversion pattern when the axis is a I64 Tensor.

PiperOrigin-RevId: 312201848
Change-Id: I55fcd3b514e9da905d0687d7c66e4da49c178ea5
This commit is contained in:
Jing Pu 2020-05-18 19:34:59 -07:00 committed by TensorFlower Gardener
parent d98a0e6017
commit 63926472df
2 changed files with 36 additions and 2 deletions

View File

@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concatv2I64Axis
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
@ -202,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite(
return success();
}
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value is out of
// the bound of i32, the function returns a failure.
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
if (attr.getType().isInteger(/*width=*/32)) {
*attr_i32 = attr;
return success();
}
int64_t value = attr.getInt();
if (value > std::numeric_limits<int>::max() ||
value < std::numeric_limits<int>::min()) {
return failure();
}
*attr_i32 = IntegerAttr::get(
IntegerType::get(/*width=*/32, attr.getContext()), value);
return success();
}
LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
@ -211,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
// Extract axis attribute from constant axis tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);
// "axis" operand could be a i64 tensor. Resolve it here.
IntegerAttr axis_i32;
if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<ConcatenationOp>(
op, output_type, values, ExtractSingleElementAsInteger(axis),
fused_activation_function);
op, output_type, values, axis_i32, fused_activation_function);
return success();
}