From b05de9d975cdddfa0c0dfd163a243e41131b7a8f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 5 Sep 2019 19:02:33 -0700 Subject: [PATCH] BatchMatMul conversion implemented PiperOrigin-RevId: 267505822 --- tensorflow/compiler/mlir/lite/BUILD | 2 + .../mlir/lite/tests/unroll-batch-matmul.mlir | 223 ++++++++++++ .../mlir/lite/transforms/prepare_tf.cc | 6 + .../lite/transforms/unroll_batch_matmul.cc | 328 ++++++++++++++++++ .../lite/transforms/unroll_batch_matmul.h | 60 ++++ .../mlir/tensorflow/ir/tf_generated_ops.td | 82 +++++ 6 files changed, 701 insertions(+) create mode 100644 tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir create mode 100644 tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc create mode 100644 tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 99740515a48..225fd393c01 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -198,9 +198,11 @@ cc_library( "transforms/prepare_composite_functions_tf.cc", "transforms/prepare_tf.cc", "transforms/trim_functions_tf.cc", + "transforms/unroll_batch_matmul.cc", ], hdrs = [ "transforms/passes.h", + "transforms/unroll_batch_matmul.h", ], deps = [ ":common", diff --git a/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir new file mode 100644 index 00000000000..09f1dfc9133 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir @@ -0,0 +1,223 @@ +// RUN: tf-opt -tfl-unroll-batch-matmul %s | FileCheck %s + +func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> + return %0 : tensor<2x3x4x6xf32> + + // CHECK-LABEL: batchMatMulV2TwoDim + // CHECK: %[[cst:.*]] = constant 
dense<[6, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> + // CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> + // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64> + // CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64> + + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> 
tensor<4x5xf32> + // CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + + // CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32> + // CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v24:.*]] = 
"tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + + // CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + + // CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> + // CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> + + // CHECK: return %[[v33]] : tensor<2x3x4x6xf32> +} + +func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + return %0 : tensor<3x4x6xf32> + + // CHECK-LABEL: batchMatMulV2FlatInput + // CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_0:.*]] = constant 
dense<[1, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> + // CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> + // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64> + // CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64> + + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + + // CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32> + // CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: 
%[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + + // CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + + // CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> + // CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32> + + // CHECK: return %[[v18]] : tensor<3x4x6xf32> +} + +func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> { + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + return %0 : tensor<4x6xf32> + + // CHECK-LABEL: batchMatMulV2Matrix + // CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: return %[[v0]] : tensor<4x6xf32> +} + +func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> { + %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> + return %0 : tensor<2x3x4x6xf32> + + // CHECK-LABEL: batchMatMulTwoDim + // CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64> + // 
CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> + // CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> + // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64> + // CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64> + + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v9:.*]] = 
"tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + + // CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32> + // CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : 
(tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + + // CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + + // CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32> + // CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32> + + // CHECK: return %[[v33]] : tensor<2x3x4x6xf32> +} + +func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { + %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + return %0 : tensor<3x4x6xf32> + + // CHECK-LABEL: batchMatMulFlatInput + // CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64> + // CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64> + // CHECK: 
%[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64> + // CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64> + // CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64> + // CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64> + // CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64> + // CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64> + + // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + + // CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32> + // CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : 
(tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + + // CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + + // CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> + // CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32> + + // CHECK: return %[[v18]] : tensor<3x4x6xf32> +} + +func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> { + %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + return %0 : tensor<4x6xf32> + + // CHECK-LABEL: batchMatMulMatrix + // CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: return %[[v0]] : tensor<4x6xf32> +} diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 7c7983ae254..102887da5ee 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -50,6 +50,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -377,6 +378,11 @@ class ConvertTFDepthwiseConv2dNative void PrepareTFPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); + + patterns.insert, + ConvertTFBatchMatMulOp>(&getContext()); + applyPatternsGreedily(func, patterns); + // This pattern was intented to uses TFL QDQs to preserve the quantization // parameters from the TF Quant ops, thus this pattern should run with the // first `applyPatternsGreedily` method, which would otherwise removes the diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc new file mode 100644 index 00000000000..1fde6accb4a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc @@ -0,0 +1,328 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This transformation pass prepares for legalization to the TFLite dialect by +// converting operations in TensorFlow dialect into operations that can be +// legalized to TensorFlow Lite dialect with simple replacements. The newly +// created operations are in the TensorFlow dialect if the operation can be +// represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op is +// used. For example, Conv2D in TFLite which uses OHWI data format for filters +// is not supported in TensorFlow because TensorFlow requires filters in the +// HWIO data format. +// +// Motivation to prepare for the TFLite legalization before the actual +// legalization is to exploit constant folding opportunities in any newly +// created ops by leveraging constant folding support for the TensorFlow ops. +// This way TFLite can be used as a serialization format only and does not +// require access to the TFLite runtime for optimizations as required by the +// TFLite team. 
+ +#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir +#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Support/Functional.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/validators.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TFL { + +namespace { +// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out +// of the inputs, matmul them individually, then stack them all back together at +// the end. 
+struct UnrollBatchMatMulPass : public FunctionPass { + void runOnFunction() override; +}; + +void UnrollBatchMatMulPass::runOnFunction() { + OwningRewritePatternList patterns; + auto func = getFunction(); + + patterns.insert, + ConvertTFBatchMatMulOp>(&getContext()); + applyPatternsGreedily(func, patterns); +} + +} // namespace + +template +TF::ReshapeOp ConvertTFBatchMatMulOp::createReshapeOp( + Value* value, ArrayRef shape, Type elementType, Location loc, + PatternRewriter& rewriter) { + int64_t shape_rank = shape.size(); + auto shapeSpecType = + rewriter.getTensorType({shape_rank}, rewriter.getIntegerType(64)); + Type resultType = rewriter.getTensorType(shape, elementType); + auto constant_attr = DenseElementsAttr::get(shapeSpecType, shape); + auto shapeTensor = + rewriter.create(loc, shapeSpecType, constant_attr); + return rewriter.create(loc, resultType, /* tensor = */ value, + /* shape = */ shapeTensor); +} + +template +std::vector ConvertTFBatchMatMulOp::sliceInput( + Value* value, int batch_size, Location loc, PatternRewriter& rewriter) { + RankedTensorType tensorType = value->getType().cast(); + Type elementType = tensorType.getElementType(); + + int rank = tensorType.getShape().size(); + int num_rows = tensorType.getShape()[rank - 2]; + int num_cols = tensorType.getShape()[rank - 1]; + + // Reshape to rank-3 Tensor with first dimension as the batch size. + auto reshapeOp = createReshapeOp(value, {batch_size, num_rows, num_cols}, + elementType, loc, rewriter); + + SmallVector sliceSize = {1, num_rows, num_cols}; + + std::vector sliced; + Type int64Type = rewriter.getIntegerType(64); + Type sliceResultType = rewriter.getTensorType(sliceSize, elementType); + + // Slice along each batch index and remember the slice output for future + // use. 
+ for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + auto vector3Type = rewriter.getTensorType({3}, int64Type); + + auto begin_attr = + DenseElementsAttr::get(vector3Type, {batch_idx, 0, 0}); + auto size_attr = DenseElementsAttr::get(vector3Type, sliceSize); + auto sliceOp = rewriter.create( + loc, sliceResultType, + /* input = */ reshapeOp.output(), + /* begin = */ + rewriter.create(loc, vector3Type, begin_attr), + /* size = */ + rewriter.create(loc, vector3Type, size_attr)); + + // Squeeze matrix, i.e. reshape [1, num_rows, num_cols] -> [num_rows, + // num_cols] + auto squeezeOp = createReshapeOp(sliceOp.output(), {num_rows, num_cols}, + elementType, loc, rewriter); + + sliced.emplace_back(squeezeOp.output()); + } + return sliced; +} + +template +TF::TransposeOp ConvertTFBatchMatMulOp::createTransposeOp( + Value* value, Location loc, PatternRewriter& rewriter) { + auto valueType = value->getType().cast(); + auto shape = valueType.getShape(); + int dims = shape.size(); + + std::vector perm(dims); + for (int i = 0; i < dims - 2; i++) { + perm[i] = i; + } + perm[dims - 2] = dims - 1; + perm[dims - 1] = dims - 2; + + auto perm_type = rewriter.getTensorType({static_cast(perm.size())}, + rewriter.getIntegerType(32)); + + auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm)); + auto perm_op = rewriter.create(loc, perm_type, perm_attr); + + std::vector transposed_shape(shape.begin(), shape.end()); + int64_t r = transposed_shape[dims - 1]; + int64_t c = transposed_shape[dims - 2]; + + transposed_shape[dims - 1] = c; + transposed_shape[dims - 2] = r; + + auto transposed_type = + rewriter.getTensorType(transposed_shape, valueType.getElementType()); + return rewriter.create(loc, transposed_type, value, perm_op); +} + +template +TF::PackOp ConvertTFBatchMatMulOp::createMatMulOps( + const std::vector& sliced_lhs, + const std::vector& sliced_rhs, const tensorflow::MatMulBCast& bcast, + int rows, int cols, Type elementType, Location loc, + 
PatternRewriter& rewriter) { + auto matmulType = rewriter.getTensorType({rows, cols}, elementType); + + std::vector matmuls; + for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) { + int lhs_batch_idx, rhs_batch_idx; + if (bcast.IsBroadcastingRequired()) { + lhs_batch_idx = bcast.x_batch_indices()[batch_idx]; + rhs_batch_idx = bcast.y_batch_indices()[batch_idx]; + } else { + lhs_batch_idx = batch_idx; + rhs_batch_idx = batch_idx; + } + auto matmul = rewriter.create( + loc, matmulType, + /* a = */ sliced_lhs[lhs_batch_idx], + /* b = */ sliced_rhs[rhs_batch_idx], + /* transpose_a = */ rewriter.getBoolAttr(false), + /* transpose_b = */ rewriter.getBoolAttr(false)); + matmuls.emplace_back(matmul.product()); + } + + // Combine the result of each individual MatMul into a rank-3 Tensor. + Type packedType = rewriter.getTensorType( + {bcast.output_batch_size(), rows, cols}, elementType); + + return rewriter.create( + loc, packedType, + /* values = */ matmuls, + /* N = */ rewriter.getI64IntegerAttr(matmuls.size()), + /* axis = */ rewriter.getI64IntegerAttr(0)); +} + +template +PatternMatchResult ConvertTFBatchMatMulOp::matchAndRewrite( + BatchMatMulOpType op, PatternRewriter& rewriter) const { + Value* input_lhs = op.x(); + Value* input_rhs = op.y(); + + if (!input_lhs->getType().isa()) { + // LHS must be a ranked tensor type + return this->matchFailure(); + } + if (!input_rhs->getType().isa()) { + // RHS must be a ranked tensor type + return this->matchFailure(); + } + + auto lhs_type = input_lhs->getType().cast(); + auto rhs_type = input_rhs->getType().cast(); + + auto elementType = lhs_type.getElementType(); + + if (elementType != rhs_type.getElementType()) { + // The element type of LHS must be the same with element type of RHS + return this->matchFailure(); + } + + auto lhs_shape = lhs_type.getShape(); + auto rhs_shape = rhs_type.getShape(); + + Location loc = op.getLoc(); + + // Transpose LHS input if necessary. 
+ if (op.adj_x()) { + input_lhs = createTransposeOp(input_lhs, loc, rewriter); + + lhs_type = input_lhs->getType().cast(); + lhs_shape = lhs_type.getShape(); + } + + // Transpose RHS input if necessary. + if (op.adj_y()) { + input_rhs = createTransposeOp(input_rhs, loc, rewriter); + + rhs_type = input_rhs->getType().cast(); + rhs_shape = rhs_type.getShape(); + } + + // Ensure that input ranks are at least 2 and batch shapes are + // broadcastable. + const int dims_a = lhs_shape.size(); + const int dims_b = rhs_shape.size(); + if (dims_a < 2 || dims_b < 2) { + // Both inputs must have rank >= 2 + return this->matchFailure(); + } + + if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) { + // Input dimensions must be compatible for multipication. + return this->matchFailure(); + } + + if (dims_a == 2 && dims_b == 2) { + // When both inputs are matrices, just replace the op to a matmul op. + Type resultType = + rewriter.getTensorType({lhs_shape[0], rhs_shape[1]}, elementType); + rewriter.replaceOpWithNewOp( + op, resultType, + /* a = */ input_lhs, + /* b = */ input_rhs, + /* transpose_a = */ rewriter.getBoolAttr(false), + /* transpose_b = */ rewriter.getBoolAttr(false)); + return this->matchSuccess(); + } + + tensorflow::MatMulBCast bcast(absl::InlinedVector( + lhs_shape.begin(), lhs_shape.end()), + absl::InlinedVector( + rhs_shape.begin(), rhs_shape.end())); + + if (!bcast.IsValid()) { + // Input batch dimensions must be broadcastable + return this->matchFailure(); + } + + // Compute slices for each batch in the LHS and RHS. + std::vector sliced_lhs = + sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter); + std::vector sliced_rhs = + sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter); + + // Compute (single batch) MatMul for each output batch. The MatMul outputs + // are then packed together into one output Tensor. 
+ auto packOp = + createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2], + rhs_shape[dims_b - 1], elementType, loc, rewriter); + + // Reshape the rank-3 Tensor into the correct output shape. + const auto& resultBatchShape = bcast.output_batch_shape().dim_sizes(); + std::vector resultShape(resultBatchShape.begin(), + resultBatchShape.end()); + resultShape.push_back(lhs_shape[dims_a - 2]); + resultShape.push_back(rhs_shape[dims_b - 1]); + + auto reshapeOp = + createReshapeOp(packOp.output(), resultShape, elementType, loc, rewriter); + rewriter.replaceOp(op, reshapeOp.output()); + return this->matchSuccess(); +} + +static PassRegistration pass( + "tfl-unroll-batch-matmul", + "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops."); + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h new file mode 100644 index 00000000000..d4b46eabf7d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TFL { + +// Unroll tf.BatchMatMulV2 op into a sequence of TF ops. Since TFLite does not +// support BatchMatMul operation, it unrolls a BatchMatMul op into tf.Reshape, +// tf.Slice, tf.MatMul, tf.Pack, and tf.Reshape ops. +template +class ConvertTFBatchMatMulOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + static TF::ReshapeOp createReshapeOp(Value* value, ArrayRef shape, + Type elementType, Location loc, + PatternRewriter& rewriter); + + static std::vector sliceInput(Value* value, int batch_size, + Location loc, + PatternRewriter& rewriter); + + static TF::TransposeOp createTransposeOp(Value* value, Location loc, + PatternRewriter& rewriter); + + static TF::PackOp createMatMulOps(const std::vector& sliced_lhs, + const std::vector& sliced_rhs, + const tensorflow::MatMulBCast& bcast, + int rows, int cols, Type elementType, + Location loc, PatternRewriter& rewriter); + + PatternMatchResult matchAndRewrite(BatchMatMulOpType op, + PatternRewriter& rewriter) const override; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 153ac5346b9..8facd952b66 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ 
b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -261,6 +261,88 @@ window in `value`.
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> {
+  let summary = "Multiplies slices of two tensors in batches.";
+
+  let description = [{
+Multiplies all slices of `Tensor` `x` and `y` (each slice can be
+viewed as an element of a batch), and arranges the individual results
+in a single output tensor of the same batch size. Each of the
+individual slices can optionally be adjointed (to adjoint a matrix
+means to transpose and conjugate it) before multiplication by setting
+the `adj_x` or `adj_y` flag to `True`, which are by default `False`.
+
+The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
+and `[..., r_y, c_y]`.
+
+The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
+
+    r_o = c_x if adj_x else r_x
+    c_o = r_y if adj_y else c_y
+
+It is computed as:
+
+    output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
+    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,
+
+    // NOTE(review): attribute template arguments were stripped by text
+    // extraction; restored to the standard ODS bool-default form.
+    DefaultValuedAttr<BoolAttr, "false">:$adj_x,
+    DefaultValuedAttr<BoolAttr, "false">:$adj_y
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
+def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect]> {
+  let summary = "Multiplies slices of two tensors in batches.";
+
+  let description = [{
+Multiplies all slices of `Tensor` `x` and `y` (each slice can be
+viewed as an element of a batch), and arranges the individual results
+in a single output tensor of the same batch size. Each of the
+individual slices can optionally be adjointed (to adjoint a matrix
+means to transpose and conjugate it) before multiplication by setting
+the `adj_x` or `adj_y` flag to `True`, which are by default `False`.
+
+The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
+and `[..., r_y, c_y]`.
+
+The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:
+
+    r_o = c_x if adj_x else r_x
+    c_o = r_y if adj_y else c_y
+
+It is computed as:
+
+    output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
+
+*NOTE*: `BatchMatMulV2` supports broadcasting in the batch dimensions. More
+about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
+    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,
+
+    DefaultValuedAttr<BoolAttr, "false">:$adj_x,
+    DefaultValuedAttr<BoolAttr, "false">:$adj_y
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
+
 def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
   let summary = "BatchToSpace for N-D tensors of type T.";