From 0f17988fd92829a577b2cf715a137377d68f9c09 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 28 Aug 2020 03:28:47 -0700
Subject: [PATCH] Support all tf.MatMul in the TFL legalization pass

PiperOrigin-RevId: 328911250
Change-Id: I70bcad27c107c069ca7c1daac96d33cc962b8be1
---
 .../compiler/mlir/lite/tests/legalize-tf.mlir | 39 +++++++++++++++-
 .../mlir/lite/transforms/legalize_tf.cc       | 46 +++++++++++++------
 2 files changed, 70 insertions(+), 15 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
index d02e4e705f4..ba98e760590 100644
--- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
@@ -1029,14 +1029,49 @@ func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<i32>) -> tensor<1x4x2x3xf32> {
 // CHECK: "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 2 : i32} : (tensor<1x4x3x3xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<1x4x2x3xf32>, tensor<1x4x1x3xf32>)
 }
 
-func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
+func @matmul(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
+  %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = false} :
+(tensor<40x37xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
+  return %0 : tensor<40x40xf32>
+// CHECK-LABEL: matmul
+// CHECK: %[[CST:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_0:.*]] = constant unit
+// CHECK: "tfl.fully_connected"(%arg0, %[[ARG]], %[[CST_0]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+}
+
+func @matmul_transposed_a(%arg0: tensor<37x40xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
+  %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = false} :
+(tensor<37x40xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
+  return %0 : tensor<40x40xf32>
+// CHECK-LABEL: matmul_transposed_a
+// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_1:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG_1:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_2:.*]] = constant unit
+// CHECK: "tfl.fully_connected"(%[[ARG_0]], %[[ARG_1]], %[[CST_2]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+}
+
+func @matmul_transposed_b(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
   %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = false, transpose_b = true} :
 (tensor<40x37xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
   return %0 : tensor<40x40xf32>
-// CHECK-LABEL: matmul_transposed
+// CHECK-LABEL: matmul_transposed_b
 // CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
 }
 
+func @matmul_transposed_ab(%arg0: tensor<37x40xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
+  %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "MatMul", transpose_a = true, transpose_b = true} :
+(tensor<37x40xf32>, tensor<40x37xf32>) -> tensor<40x40xf32>
+  return %0 : tensor<40x40xf32>
+// CHECK-LABEL: matmul_transposed_ab
+// CHECK: %[[CST_0:.*]] = constant dense<[1, 0]> : tensor<2xi32>
+// CHECK: %[[ARG_0:.*]] = "tfl.transpose"(%arg0, %[[CST_0]]) : (tensor<37x40xf32>, tensor<2xi32>) -> tensor<40x37xf32>
+// CHECK: %[[CST_1:.*]] = constant unit
+// CHECK: "tfl.fully_connected"(%[[ARG_0]], %arg1, %[[CST_1]]) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
+}
+
 func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
   %0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
   %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
index b31da15c35f..bc460dbce2a 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -66,7 +66,6 @@ namespace TFL {
 // The actual LegalizeTF Pass.
 namespace {
 
-using xla::Status;
 using xla::StatusOr;
 
 constexpr char kUnidirectionalSequenceLstm[] = "tf.UnidirectionalSequenceLstm";
@@ -232,26 +231,47 @@ LogicalResult ConvertTFConcatV2Op::matchAndRewrite(
   return success();
 }
 
-// The following is effectively:
-// def : Pat<
-//   (TF_MatMulOp $a, $b, ConstBoolAttrFalse:$transpose_a,
-//                ConstBoolAttrTrue:$transpose_b),
-//   (TFL_FullyConnectedOp:$__0 $a, $b,
-//     NoInput.pattern, TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrFalse)>;
 LogicalResult ConvertTFMatMulOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_matmul_op = cast<TF::MatMulOp>(op);
-  if (tf_matmul_op.transpose_a()) return failure();
-  if (!tf_matmul_op.transpose_b()) return failure();
+  auto lhs = op->getOperand(0);
+  auto rhs = op->getOperand(1);
+  auto transpose = [&](Value input) -> std::pair<LogicalResult, Value> {
+    RankedTensorType type =
+        input.getType().dyn_cast_or_null<RankedTensorType>();
+    if (!type || type.getRank() != 2) return {failure(), nullptr};
+
+    auto permute_attr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI32Type()), {1, 0});
+    auto permute = rewriter.create<ConstantOp>(
+        op->getLoc(), permute_attr.getType(), permute_attr);
+    llvm::SmallVector<int64_t, 2> new_shape{type.getShape()[1],
+                                            type.getShape()[0]};
+    auto output = rewriter.create<TFL::TransposeOp>(
+        op->getLoc(), RankedTensorType::get(new_shape, type.getElementType()),
+        input, permute);
+    return {success(), output};
+  };
+
+  // TODO(jpienaar): Remove once handled via dialect conversion.
+  if (tf_matmul_op.transpose_a()) {
+    LogicalResult result = success();
+    std::tie(result, lhs) = transpose(lhs);
+    if (failed(result)) return failure();
+  }
+  if (!tf_matmul_op.transpose_b()) {
+    LogicalResult result = success();
+    std::tie(result, rhs) = transpose(rhs);
+    if (failed(result)) return failure();
+  }
 
   Type output_type = tf_matmul_op.getResult().getType();
-  // TODO(jpienaar): Follow up post shuffle discussion.
   auto no_input = rewriter.create<ConstantOp>(
       op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr());
   auto fc_op = rewriter.create<FullyConnectedOp>(
-      op->getLoc(), ArrayRef<Type>{output_type}, op->getOperand(0),
-      op->getOperand(1), no_input, rewriter.getStringAttr("NONE"),
-      rewriter.getStringAttr("DEFAULT"), rewriter.getBoolAttr(false));
+      op->getLoc(), ArrayRef<Type>{output_type}, lhs, rhs, no_input,
+      rewriter.getStringAttr("NONE"), rewriter.getStringAttr("DEFAULT"),
+      rewriter.getBoolAttr(false));
   rewriter.replaceOp(op, {fc_op.getResult(0)});
   return success();
 }
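
Note on the lowering (a sketch, not part of the applied diff): "tfl.fully_connected" takes its filter in [num_output, num_input] layout and computes output = input x transpose(filter). That is why the transpose_a = false / transpose_b = true case maps onto it directly, while the other three attribute combinations now get an explicit "tfl.transpose" on the offending operand(s) first, exactly as the new CHECK lines verify. A minimal input that exercises the plain-matmul path, assuming the test file's usual RUN setup (tf-opt -tfl-legalize-tf piped into FileCheck); the function name is illustrative:

  func @matmul_demo(%arg0: tensor<40x37xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
    // transpose_b = false, so legalization must first transpose %arg1 from
    // 37x40 into the 40x37 filter layout that "tfl.fully_connected" expects.
    %0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} :
  (tensor<40x37xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
    return %0 : tensor<40x40xf32>
  }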
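
The transpose helper above only handles ranked rank-2 operands; for anything else it returns failure, so matchAndRewrite bails out and the pattern leaves the op untouched rather than miscompiling it. A hedged sketch of an input that should therefore survive the pass as tf.MatMul (the unranked lhs fails the rank check; names are illustrative):

  func @matmul_unranked(%arg0: tensor<*xf32>, %arg1: tensor<37x40xf32>) -> tensor<40x40xf32> {
    // transpose_a = true requires transposing %arg0, but its type is
    // unranked, so no TFL ops are emitted for this MatMul.
    %0 = "tf.MatMul"(%arg0, %arg1) {transpose_a = true, transpose_b = false} :
  (tensor<*xf32>, tensor<37x40xf32>) -> tensor<40x40xf32>
    return %0 : tensor<40x40xf32>
  }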