Canonicalize TensorFlow Concat op to ConcatV2 op
PiperOrigin-RevId: 315980702 Change-Id: I4d2b405e473e89462c3090ba73e61919f6e2c75c
This commit is contained in:
parent
66f4af41a0
commit
7e8866d18e
@ -1021,24 +1021,6 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t
|
||||
// CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
|
||||
}
|
||||
|
||||
func @concat2Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
|
||||
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
|
||||
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||
return %1 : tensor<2x2xi32>
|
||||
|
||||
// CHECK-LABEL: concat2Tensors
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||
}
|
||||
|
||||
func @concat3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
|
||||
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
||||
return %1 : tensor<2x3xi32>
|
||||
|
||||
// CHECK-LABEL: concat3Tensors
|
||||
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
||||
}
|
||||
|
||||
func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
|
||||
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>
|
||||
|
@ -123,7 +123,6 @@ bool HasSameStaticShapes(Operation* op) {
|
||||
// operands are properly supported in declarative rewrite rule specification.
|
||||
|
||||
DECL_CONVERT_OP(Assert);
|
||||
DECL_CONVERT_OP(Concat);
|
||||
DECL_CONVERT_OP(ConcatV2);
|
||||
DECL_CONVERT_OP(MatMul);
|
||||
DECL_CONVERT_OP(MatrixDiagV2);
|
||||
@ -184,25 +183,6 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult ConvertTFConcatOp::matchAndRewrite(
|
||||
Operation* op, PatternRewriter& rewriter) const {
|
||||
auto tf_concat_op = cast<TF::ConcatOp>(op);
|
||||
|
||||
auto values = tf_concat_op.values();
|
||||
auto output_type = tf_concat_op.output().getType();
|
||||
// Extract axis attribute from constant concat_dims tensor
|
||||
ElementsAttr axis;
|
||||
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
|
||||
return failure();
|
||||
|
||||
StringAttr fused_activation_function =
|
||||
StringAttr::get("NONE", rewriter.getContext());
|
||||
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
|
||||
op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis),
|
||||
fused_activation_function);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
|
||||
// The value won't change in the new attribute, but if the value is out of
|
||||
// the bound of i32, the function returns a failure.
|
||||
@ -756,12 +736,12 @@ void LegalizeTF::runOnFunction() {
|
||||
|
||||
// Add the generated patterns to the list.
|
||||
populateWithGenerated(context, &patterns);
|
||||
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
|
||||
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
|
||||
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
|
||||
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
|
||||
ConvertTFAssertOp, ConvertTFReciprocalOp,
|
||||
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
|
||||
patterns
|
||||
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
|
||||
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
|
||||
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
|
||||
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
|
||||
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
|
||||
|
||||
// Ophint python converter converted tf node pattern.
|
||||
patterns.insert<LegalizeUnidirectionalSequenceLstm,
|
||||
|
@ -1521,6 +1521,8 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> {
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_ConcatOffsetOp : TF_Op<"ConcatOffset", [NoSideEffect]> {
|
||||
|
@ -873,6 +873,11 @@ static LogicalResult Verify(OpT op) {
|
||||
/*mask_one_dim=*/true, op.getOperation());
|
||||
}
|
||||
|
||||
void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<ConvertToConcatV2>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConcatOffsetOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -133,6 +133,16 @@ func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x3
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testConcatCanonicalization
|
||||
func @testConcatCanonicalization(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[AXIS:.*]] = "tf.Const"
|
||||
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
|
||||
|
||||
// CHECK: "tf.ConcatV2"(%arg0, %arg1, %[[AXIS]])
|
||||
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||
return %1 : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testLogOfSoftmax
|
||||
func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
|
@ -83,6 +83,13 @@ def BitcastSameType : Pat<(TF_BitcastOp:$res $arg), (replaceWithValue $arg),
|
||||
def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
|
||||
(TF_BitcastOp $arg)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Concat op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertToConcatV2 : Pat<(TF_ConcatOp $axis, $inputs),
|
||||
(TF_ConcatV2Op $inputs, $axis)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Conj op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
Loading…
Reference in New Issue
Block a user