Canonicalize TensorFlow Concat op to ConcatV2 op

PiperOrigin-RevId: 315980702
Change-Id: I4d2b405e473e89462c3090ba73e61919f6e2c75c
Authored by Smit Hinsu on 2020-06-11 14:39:36 -07:00; committed by TensorFlower Gardener
parent 66f4af41a0
commit 7e8866d18e
6 changed files with 30 additions and 44 deletions
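
At the IR level, the new ConvertToConcatV2 canonicalization rewrites tf.Concat, which takes the concatenation axis as its first operand, into tf.ConcatV2, which takes the axis as its last operand; the TFLite legalizer then only needs to handle tf.ConcatV2, so its dedicated tf.Concat conversion is removed below. A minimal sketch of the rewrite (function and value names are illustrative), mirroring the canonicalization test added in this change:

func @concat_example(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
  %axis = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
  // Before canonicalization (axis is the first operand):
  //   %0 = "tf.Concat"(%axis, %arg0, %arg1)
  //       : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
  // After canonicalization (axis becomes the last operand):
  %0 = "tf.ConcatV2"(%arg0, %arg1, %axis)
      : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}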


@@ -1021,24 +1021,6 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t
// CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func @concat2Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
// CHECK-LABEL: concat2Tensors
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
}
func @concat3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concat3Tensors
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>


@@ -123,7 +123,6 @@ bool HasSameStaticShapes(Operation* op) {
// operands are properly supported in declarative rewrite rule specification.
DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Concat);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(MatrixDiagV2);
@@ -184,25 +183,6 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
return failure();
}
LogicalResult ConvertTFConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatOp>(op);
auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output().getType();
// Extract axis attribute from constant concat_dims tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
return failure();
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis),
fused_activation_function);
return success();
}
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value is out of
// the bound of i32, the function returns a failure.
@@ -756,12 +736,12 @@ void LegalizeTF::runOnFunction() {
// Add the generated patterns to the list.
populateWithGenerated(context, &patterns);
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
ConvertTFAssertOp, ConvertTFReciprocalOp,
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
patterns
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
// Ophint python converter converted tf node pattern.
patterns.insert<LegalizeUnidirectionalSequenceLstm,


@@ -1521,6 +1521,8 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> {
let verifier = [{
return Verify(*this);
}];
let hasCanonicalizer = 1;
}
def TF_ConcatOffsetOp : TF_Op<"ConcatOffset", [NoSideEffect]> {


@@ -873,6 +873,11 @@ static LogicalResult Verify(OpT op) {
/*mask_one_dim=*/true, op.getOperation());
}
void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ConvertToConcatV2>(context);
}
//===----------------------------------------------------------------------===//
// ConcatOffsetOp
//===----------------------------------------------------------------------===//


@@ -133,6 +133,16 @@ func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x3
// CHECK: return %arg0
}
// CHECK-LABEL: testConcatCanonicalization
func @testConcatCanonicalization(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
// CHECK: %[[AXIS:.*]] = "tf.Const"
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
// CHECK: "tf.ConcatV2"(%arg0, %arg1, %[[AXIS]])
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
}
// CHECK-LABEL: testLogOfSoftmax
func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>


@@ -83,6 +83,13 @@ def BitcastSameType : Pat<(TF_BitcastOp:$res $arg), (replaceWithValue $arg),
def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
(TF_BitcastOp $arg)>;
//===----------------------------------------------------------------------===//
// Concat op patterns.
//===----------------------------------------------------------------------===//
def ConvertToConcatV2 : Pat<(TF_ConcatOp $axis, $inputs),
(TF_ConcatV2Op $inputs, $axis)>;
//===----------------------------------------------------------------------===//
// Conj op patterns.
//===----------------------------------------------------------------------===//