Fix TF_ConcatV2Op conversion pattern when the axis is a I64 Tensor.
PiperOrigin-RevId: 312201848 Change-Id: I55fcd3b514e9da905d0687d7c66e4da49c178ea5
This commit is contained in:
parent
d98a0e6017
commit
63926472df
@ -1048,6 +1048,15 @@ func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @concatv2I64Axis(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i64> } : () -> tensor<i64>
|
||||
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i64>) -> tensor<2x3xi32>
|
||||
return %1 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-LABEL: concatv2I64Axis
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @resize_with_bilinear(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor<?xf32> {
|
||||
%0 = "tf.ResizeBilinear"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
@ -202,6 +203,26 @@ LogicalResult ConvertTFConcatOp::matchAndRewrite(
|
||||
return success();
|
||||
}
|
||||
|
||||
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
|
||||
// The value won't change in the new attribute, but if the value is out of
|
||||
// the bound of i32, the function returns a failure.
|
||||
LogicalResult ConvertToI32Attr(IntegerAttr attr, IntegerAttr* attr_i32) {
|
||||
if (attr.getType().isInteger(/*width=*/32)) {
|
||||
*attr_i32 = attr;
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t value = attr.getInt();
|
||||
if (value > std::numeric_limits<int>::max() ||
|
||||
value < std::numeric_limits<int>::min()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
*attr_i32 = IntegerAttr::get(
|
||||
IntegerType::get(/*width=*/32, attr.getContext()), value);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_concat_op = cast<TF::ConcatV2Op>(op);
|
||||
@ -211,12 +232,16 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
|
||||
// Extract axis attribute from constant axis tensor
|
||||
ElementsAttr axis;
|
||||
if (!matchPattern(tf_concat_op.axis(), m_Constant(&axis))) return failure();
|
||||
IntegerAttr axis_int = ExtractSingleElementAsInteger(axis);
|
||||
|
||||
// "axis" operand could be a i64 tensor. Resolve it here.
|
||||
IntegerAttr axis_i32;
|
||||
if (failed(ConvertToI32Attr(axis_int, &axis_i32))) return failure();
|
||||
|
||||
StringAttr fused_activation_function =
|
||||
StringAttr::get("NONE", rewriter.getContext());
|
||||
rewriter.replaceOpWithNewOp<ConcatenationOp>(
|
||||
op, output_type, values, ExtractSingleElementAsInteger(axis),
|
||||
fused_activation_function);
|
||||
op, output_type, values, axis_i32, fused_activation_function);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user