Canonicalize TensorFlow Concat op to ConcatV2 op
PiperOrigin-RevId: 315980702 Change-Id: I4d2b405e473e89462c3090ba73e61919f6e2c75c
This commit is contained in:
parent
66f4af41a0
commit
7e8866d18e
@ -1021,24 +1021,6 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t
|
|||||||
// CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
|
// CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
func @concat2Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
|
|
||||||
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
|
|
||||||
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
|
||||||
return %1 : tensor<2x2xi32>
|
|
||||||
|
|
||||||
// CHECK-LABEL: concat2Tensors
|
|
||||||
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
func @concat3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
|
||||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
|
|
||||||
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
|
||||||
return %1 : tensor<2x3xi32>
|
|
||||||
|
|
||||||
// CHECK-LABEL: concat3Tensors
|
|
||||||
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
|
|
||||||
}
|
|
||||||
|
|
||||||
func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
|
||||||
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
|
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
|
||||||
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>
|
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>
|
||||||
|
@ -123,7 +123,6 @@ bool HasSameStaticShapes(Operation* op) {
|
|||||||
// operands are properly supported in declarative rewrite rule specification.
|
// operands are properly supported in declarative rewrite rule specification.
|
||||||
|
|
||||||
DECL_CONVERT_OP(Assert);
|
DECL_CONVERT_OP(Assert);
|
||||||
DECL_CONVERT_OP(Concat);
|
|
||||||
DECL_CONVERT_OP(ConcatV2);
|
DECL_CONVERT_OP(ConcatV2);
|
||||||
DECL_CONVERT_OP(MatMul);
|
DECL_CONVERT_OP(MatMul);
|
||||||
DECL_CONVERT_OP(MatrixDiagV2);
|
DECL_CONVERT_OP(MatrixDiagV2);
|
||||||
@ -184,25 +183,6 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ConvertTFConcatOp::matchAndRewrite(
|
|
||||||
Operation* op, PatternRewriter& rewriter) const {
|
|
||||||
auto tf_concat_op = cast<TF::ConcatOp>(op);
|
|
||||||
|
|
||||||
auto values = tf_concat_op.values();
|
|
||||||
auto output_type = tf_concat_op.output().getType();
|
|
||||||
// Extract axis attribute from constant concat_dims tensor
|
|
||||||
ElementsAttr axis;
|
|
||||||
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
StringAttr fused_activation_function =
|
|
||||||
StringAttr::get("NONE", rewriter.getContext());
|
|
||||||
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
|
|
||||||
op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis),
|
|
||||||
fused_activation_function);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
|
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
|
||||||
// The value won't change in the new attribute, but if the value is out of
|
// The value won't change in the new attribute, but if the value is out of
|
||||||
// the bound of i32, the function returns a failure.
|
// the bound of i32, the function returns a failure.
|
||||||
@ -756,11 +736,11 @@ void LegalizeTF::runOnFunction() {
|
|||||||
|
|
||||||
// Add the generated patterns to the list.
|
// Add the generated patterns to the list.
|
||||||
populateWithGenerated(context, &patterns);
|
populateWithGenerated(context, &patterns);
|
||||||
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
|
patterns
|
||||||
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
|
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
|
||||||
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
|
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
|
||||||
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
|
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
|
||||||
ConvertTFAssertOp, ConvertTFReciprocalOp,
|
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
|
||||||
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
|
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
|
||||||
|
|
||||||
// Ophint python converter converted tf node pattern.
|
// Ophint python converter converted tf node pattern.
|
||||||
|
@ -1521,6 +1521,8 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> {
|
|||||||
let verifier = [{
|
let verifier = [{
|
||||||
return Verify(*this);
|
return Verify(*this);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_ConcatOffsetOp : TF_Op<"ConcatOffset", [NoSideEffect]> {
|
def TF_ConcatOffsetOp : TF_Op<"ConcatOffset", [NoSideEffect]> {
|
||||||
|
@ -873,6 +873,11 @@ static LogicalResult Verify(OpT op) {
|
|||||||
/*mask_one_dim=*/true, op.getOperation());
|
/*mask_one_dim=*/true, op.getOperation());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||||
|
MLIRContext *context) {
|
||||||
|
results.insert<ConvertToConcatV2>(context);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ConcatOffsetOp
|
// ConcatOffsetOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -133,6 +133,16 @@ func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x3
|
|||||||
// CHECK: return %arg0
|
// CHECK: return %arg0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: testConcatCanonicalization
|
||||||
|
func @testConcatCanonicalization(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
|
||||||
|
// CHECK: %[[AXIS:.*]] = "tf.Const"
|
||||||
|
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
|
||||||
|
|
||||||
|
// CHECK: "tf.ConcatV2"(%arg0, %arg1, %[[AXIS]])
|
||||||
|
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
|
||||||
|
return %1 : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: testLogOfSoftmax
|
// CHECK-LABEL: testLogOfSoftmax
|
||||||
func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||||
%0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
%0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||||
|
@ -83,6 +83,13 @@ def BitcastSameType : Pat<(TF_BitcastOp:$res $arg), (replaceWithValue $arg),
|
|||||||
def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
|
def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
|
||||||
(TF_BitcastOp $arg)>;
|
(TF_BitcastOp $arg)>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Concat op patterns.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def ConvertToConcatV2 : Pat<(TF_ConcatOp $axis, $inputs),
|
||||||
|
(TF_ConcatV2Op $inputs, $axis)>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Conj op patterns.
|
// Conj op patterns.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
Loading…
Reference in New Issue
Block a user