BatchMatMul conversion implemented
PiperOrigin-RevId: 267505822
This commit is contained in:
parent
484e8acedc
commit
b05de9d975
@ -198,9 +198,11 @@ cc_library(
|
||||
"transforms/prepare_composite_functions_tf.cc",
|
||||
"transforms/prepare_tf.cc",
|
||||
"transforms/trim_functions_tf.cc",
|
||||
"transforms/unroll_batch_matmul.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"transforms/passes.h",
|
||||
"transforms/unroll_batch_matmul.h",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
|
223
tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir
Normal file
223
tensorflow/compiler/mlir/lite/tests/unroll-batch-matmul.mlir
Normal file
@ -0,0 +1,223 @@
|
||||
// RUN: tf-opt -tfl-unroll-batch-matmul %s | FileCheck %s
|
||||
|
||||
func @batchMatMulV2TwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
|
||||
return %0 : tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulV2TwoDim
|
||||
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
|
||||
// CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
|
||||
// CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
|
||||
// CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
|
||||
// CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
|
||||
// CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
|
||||
}
|
||||
|
||||
func @batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
|
||||
return %0 : tensor<3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulV2FlatInput
|
||||
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
|
||||
// CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
|
||||
// CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
|
||||
// CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
|
||||
// CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
|
||||
// CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>
|
||||
|
||||
// CHECK: return %[[v18]] : tensor<3x4x6xf32>
|
||||
}
|
||||
|
||||
func @batchMatMulV2Matrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
return %0 : tensor<4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulV2Matrix
|
||||
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: return %[[v0]] : tensor<4x6xf32>
|
||||
}
|
||||
|
||||
func @batchMatMulTwoDim(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32> {
|
||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<2x3x5x6xf32>) -> tensor<2x3x4x6xf32>
|
||||
return %0 : tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulTwoDim
|
||||
// CHECK: %[[cst:.*]] = constant dense<[6, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[6, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[3, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[4, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[5, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_9:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_10:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_11:.*]] = constant dense<[2, 3, 4, 6]> : tensor<4xi64>
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<2x3x4x5xf32>, tensor<3xi64>) -> tensor<6x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v7:.*]] = "tf.Slice"(%[[v0]], %[[cst_6]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v8:.*]] = "tf.Reshape"(%[[v7]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v9:.*]] = "tf.Slice"(%[[v0]], %[[cst_7]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v10:.*]] = "tf.Reshape"(%[[v9]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v11:.*]] = "tf.Slice"(%[[v0]], %[[cst_8]], %[[cst_0]]) : (tensor<6x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v12:.*]] = "tf.Reshape"(%[[v11]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
|
||||
// CHECK: %[[v13:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<2x3x5x6xf32>, tensor<3xi64>) -> tensor<6x5x6xf32>
|
||||
// CHECK: %[[v14:.*]] = "tf.Slice"(%[[v13]], %[[cst_3]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v15:.*]] = "tf.Reshape"(%[[v14]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v16:.*]] = "tf.Slice"(%[[v13]], %[[cst_4]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v17:.*]] = "tf.Reshape"(%[[v16]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v18:.*]] = "tf.Slice"(%[[v13]], %[[cst_5]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v19:.*]] = "tf.Reshape"(%[[v18]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v20:.*]] = "tf.Slice"(%[[v13]], %[[cst_6]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v21:.*]] = "tf.Reshape"(%[[v20]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v22:.*]] = "tf.Slice"(%[[v13]], %[[cst_7]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v23:.*]] = "tf.Reshape"(%[[v22]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v24:.*]] = "tf.Slice"(%[[v13]], %[[cst_8]], %[[cst_9]]) : (tensor<6x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v25:.*]] = "tf.Reshape"(%[[v24]], %[[cst_10]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
|
||||
// CHECK: %[[v26:.*]] = "tf.MatMul"(%[[v2]], %[[v15]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v27:.*]] = "tf.MatMul"(%[[v4]], %[[v17]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v28:.*]] = "tf.MatMul"(%[[v6]], %[[v19]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v29:.*]] = "tf.MatMul"(%[[v8]], %[[v21]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v30:.*]] = "tf.MatMul"(%[[v10]], %[[v23]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v31:.*]] = "tf.MatMul"(%[[v12]], %[[v25]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
|
||||
// CHECK: %[[v32:.*]] = "tf.Pack"(%[[v26]], %[[v27]], %[[v28]], %[[v29]], %[[v30]], %[[v31]]) {N = 6 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<6x4x6xf32>
|
||||
// CHECK: %[[v33:.*]] = "tf.Reshape"(%[[v32]], %[[cst_11]]) : (tensor<6x4x6xf32>, tensor<4xi64>) -> tensor<2x3x4x6xf32>
|
||||
|
||||
// CHECK: return %[[v33]] : tensor<2x3x4x6xf32>
|
||||
}
|
||||
|
||||
func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> {
|
||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32>
|
||||
return %0 : tensor<3x4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulFlatInput
|
||||
// CHECK: %[[cst:.*]] = constant dense<[3, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_0:.*]] = constant dense<[1, 4, 5]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_1:.*]] = constant dense<[4, 5]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_2:.*]] = constant dense<[3, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_3:.*]] = constant dense<0> : tensor<3xi64>
|
||||
// CHECK: %[[cst_4:.*]] = constant dense<[1, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_5:.*]] = constant dense<[2, 0, 0]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_6:.*]] = constant dense<[1, 5, 6]> : tensor<3xi64>
|
||||
// CHECK: %[[cst_7:.*]] = constant dense<[5, 6]> : tensor<2xi64>
|
||||
// CHECK: %[[cst_8:.*]] = constant dense<[3, 4, 6]> : tensor<3xi64>
|
||||
|
||||
// CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32>
|
||||
// CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
// CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32>
|
||||
// CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32>
|
||||
|
||||
// CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32>
|
||||
// CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
// CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32>
|
||||
// CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32>
|
||||
|
||||
// CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
|
||||
// CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {N = 3 : i64, axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32>
|
||||
// CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32>
|
||||
|
||||
// CHECK: return %[[v18]] : tensor<3x4x6xf32>
|
||||
}
|
||||
|
||||
func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tensor<4x6xf32> {
|
||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
return %0 : tensor<4x6xf32>
|
||||
|
||||
// CHECK-LABEL: batchMatMulMatrix
|
||||
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||
// CHECK: return %[[v0]] : tensor<4x6xf32>
|
||||
}
|
@ -50,6 +50,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
@ -377,6 +378,11 @@ class ConvertTFDepthwiseConv2dNative
|
||||
void PrepareTFPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
|
||||
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
|
||||
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
|
||||
applyPatternsGreedily(func, patterns);
|
||||
|
||||
// This pattern was intented to uses TFL QDQs to preserve the quantization
|
||||
// parameters from the TF Quant ops, thus this pattern should run with the
|
||||
// first `applyPatternsGreedily` method, which would otherwise removes the
|
||||
|
328
tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
Normal file
328
tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.cc
Normal file
@ -0,0 +1,328 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass prepares for legalization to the TFLite dialect by
|
||||
// converting operations in TensorFlow dialect into operations that can be
|
||||
// legalized to TensorFlow Lite dialect with simple replacements. The newly
|
||||
// created operations are in the TensorFlow dialect if the operation can be
|
||||
// represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op is
|
||||
// used. For example, Conv2D in TFLite which uses OHWI data format for filters
|
||||
// is not supported in TensorFlow because TensorFlow requires filters in the
|
||||
// HWIO data format.
|
||||
//
|
||||
// Motivation to prepare for the TFLite legalization before the actual
|
||||
// legalization is to exploit constant folding opportunities in any newly
|
||||
// created ops by leveraging constant folding support for the TensorFlow ops.
|
||||
// This way TFLite can be used as a serialization format only and does not
|
||||
// require access to the TFLite runtime for optimizations as required by the
|
||||
// TFLite team.
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/unroll_batch_matmul.h"
|
||||
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/Dialect/QuantOps/UniformSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/core/util/matmul_bcast.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
namespace {
|
||||
// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
|
||||
// of the inputs, matmul them individually, then stack them all back together at
|
||||
// the end.
|
||||
struct UnrollBatchMatMulPass : public FunctionPass<UnrollBatchMatMulPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
void UnrollBatchMatMulPass::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
|
||||
patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
|
||||
ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
|
||||
applyPatternsGreedily(func, patterns);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename BatchMatMulOpType>
|
||||
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
|
||||
Value* value, ArrayRef<int64_t> shape, Type elementType, Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
int64_t shape_rank = shape.size();
|
||||
auto shapeSpecType =
|
||||
rewriter.getTensorType({shape_rank}, rewriter.getIntegerType(64));
|
||||
Type resultType = rewriter.getTensorType(shape, elementType);
|
||||
auto constant_attr = DenseElementsAttr::get(shapeSpecType, shape);
|
||||
auto shapeTensor =
|
||||
rewriter.create<ConstantOp>(loc, shapeSpecType, constant_attr);
|
||||
return rewriter.create<TF::ReshapeOp>(loc, resultType, /* tensor = */ value,
|
||||
/* shape = */ shapeTensor);
|
||||
}
|
||||
|
||||
template <typename BatchMatMulOpType>
|
||||
std::vector<Value*> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
|
||||
Value* value, int batch_size, Location loc, PatternRewriter& rewriter) {
|
||||
RankedTensorType tensorType = value->getType().cast<RankedTensorType>();
|
||||
Type elementType = tensorType.getElementType();
|
||||
|
||||
int rank = tensorType.getShape().size();
|
||||
int num_rows = tensorType.getShape()[rank - 2];
|
||||
int num_cols = tensorType.getShape()[rank - 1];
|
||||
|
||||
// Reshape to rank-3 Tensor with first dimension as the batch size.
|
||||
auto reshapeOp = createReshapeOp(value, {batch_size, num_rows, num_cols},
|
||||
elementType, loc, rewriter);
|
||||
|
||||
SmallVector<int64_t, 3> sliceSize = {1, num_rows, num_cols};
|
||||
|
||||
std::vector<Value*> sliced;
|
||||
Type int64Type = rewriter.getIntegerType(64);
|
||||
Type sliceResultType = rewriter.getTensorType(sliceSize, elementType);
|
||||
|
||||
// Slice along each batch index and remember the slice output for future
|
||||
// use.
|
||||
for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||
auto vector3Type = rewriter.getTensorType({3}, int64Type);
|
||||
|
||||
auto begin_attr =
|
||||
DenseElementsAttr::get<int64_t>(vector3Type, {batch_idx, 0, 0});
|
||||
auto size_attr = DenseElementsAttr::get<int64_t>(vector3Type, sliceSize);
|
||||
auto sliceOp = rewriter.create<TF::SliceOp>(
|
||||
loc, sliceResultType,
|
||||
/* input = */ reshapeOp.output(),
|
||||
/* begin = */
|
||||
rewriter.create<ConstantOp>(loc, vector3Type, begin_attr),
|
||||
/* size = */
|
||||
rewriter.create<ConstantOp>(loc, vector3Type, size_attr));
|
||||
|
||||
// Squeeze matrix, i.e. reshape [1, num_rows, num_cols] -> [num_rows,
|
||||
// num_cols]
|
||||
auto squeezeOp = createReshapeOp(sliceOp.output(), {num_rows, num_cols},
|
||||
elementType, loc, rewriter);
|
||||
|
||||
sliced.emplace_back(squeezeOp.output());
|
||||
}
|
||||
return sliced;
|
||||
}
|
||||
|
||||
template <typename BatchMatMulOpType>
|
||||
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
|
||||
Value* value, Location loc, PatternRewriter& rewriter) {
|
||||
auto valueType = value->getType().cast<RankedTensorType>();
|
||||
auto shape = valueType.getShape();
|
||||
int dims = shape.size();
|
||||
|
||||
std::vector<int32_t> perm(dims);
|
||||
for (int i = 0; i < dims - 2; i++) {
|
||||
perm[i] = i;
|
||||
}
|
||||
perm[dims - 2] = dims - 1;
|
||||
perm[dims - 1] = dims - 2;
|
||||
|
||||
auto perm_type = rewriter.getTensorType({static_cast<int32_t>(perm.size())},
|
||||
rewriter.getIntegerType(32));
|
||||
|
||||
auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm));
|
||||
auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);
|
||||
|
||||
std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
|
||||
int64_t r = transposed_shape[dims - 1];
|
||||
int64_t c = transposed_shape[dims - 2];
|
||||
|
||||
transposed_shape[dims - 1] = c;
|
||||
transposed_shape[dims - 2] = r;
|
||||
|
||||
auto transposed_type =
|
||||
rewriter.getTensorType(transposed_shape, valueType.getElementType());
|
||||
return rewriter.create<TF::TransposeOp>(loc, transposed_type, value, perm_op);
|
||||
}
|
||||
|
||||
template <typename BatchMatMulOpType>
|
||||
TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
|
||||
const std::vector<Value*>& sliced_lhs,
|
||||
const std::vector<Value*>& sliced_rhs, const tensorflow::MatMulBCast& bcast,
|
||||
int rows, int cols, Type elementType, Location loc,
|
||||
PatternRewriter& rewriter) {
|
||||
auto matmulType = rewriter.getTensorType({rows, cols}, elementType);
|
||||
|
||||
std::vector<Value*> matmuls;
|
||||
for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
|
||||
int lhs_batch_idx, rhs_batch_idx;
|
||||
if (bcast.IsBroadcastingRequired()) {
|
||||
lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
|
||||
rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
|
||||
} else {
|
||||
lhs_batch_idx = batch_idx;
|
||||
rhs_batch_idx = batch_idx;
|
||||
}
|
||||
auto matmul = rewriter.create<TF::MatMulOp>(
|
||||
loc, matmulType,
|
||||
/* a = */ sliced_lhs[lhs_batch_idx],
|
||||
/* b = */ sliced_rhs[rhs_batch_idx],
|
||||
/* transpose_a = */ rewriter.getBoolAttr(false),
|
||||
/* transpose_b = */ rewriter.getBoolAttr(false));
|
||||
matmuls.emplace_back(matmul.product());
|
||||
}
|
||||
|
||||
// Combine the result of each individual MatMul into a rank-3 Tensor.
|
||||
Type packedType = rewriter.getTensorType(
|
||||
{bcast.output_batch_size(), rows, cols}, elementType);
|
||||
|
||||
return rewriter.create<TF::PackOp>(
|
||||
loc, packedType,
|
||||
/* values = */ matmuls,
|
||||
/* N = */ rewriter.getI64IntegerAttr(matmuls.size()),
|
||||
/* axis = */ rewriter.getI64IntegerAttr(0));
|
||||
}
|
||||
|
||||
// Rewrites a tf.BatchMatMul / tf.BatchMatMulV2 into primitive TF ops:
// (optional tf.Transpose for adj_x/adj_y) + tf.Slice-based batch splitting +
// one tf.MatMul per output batch + tf.Pack + final tf.Reshape. Fails the
// match when either operand is unranked, element types differ, ranks are
// below 2, inner dimensions are incompatible, or batch shapes do not
// broadcast.
template <typename BatchMatMulOpType>
PatternMatchResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
    BatchMatMulOpType op, PatternRewriter& rewriter) const {
  Value* input_lhs = op.x();
  Value* input_rhs = op.y();

  if (!input_lhs->getType().isa<RankedTensorType>()) {
    // LHS must be a ranked tensor type
    return this->matchFailure();
  }
  if (!input_rhs->getType().isa<RankedTensorType>()) {
    // RHS must be a ranked tensor type
    return this->matchFailure();
  }

  auto lhs_type = input_lhs->getType().cast<RankedTensorType>();
  auto rhs_type = input_rhs->getType().cast<RankedTensorType>();

  auto elementType = lhs_type.getElementType();

  if (elementType != rhs_type.getElementType()) {
    // The element type of LHS must be the same with element type of RHS
    return this->matchFailure();
  }

  auto lhs_shape = lhs_type.getShape();
  auto rhs_shape = rhs_type.getShape();

  Location loc = op.getLoc();

  // Transpose LHS input if necessary. Note: only the transpose part of the
  // adjoint is materialized here; type/shape are refreshed afterwards since
  // the rest of the rewrite works on the transposed operand.
  if (op.adj_x()) {
    input_lhs = createTransposeOp(input_lhs, loc, rewriter);

    lhs_type = input_lhs->getType().cast<RankedTensorType>();
    lhs_shape = lhs_type.getShape();
  }

  // Transpose RHS input if necessary.
  if (op.adj_y()) {
    input_rhs = createTransposeOp(input_rhs, loc, rewriter);

    rhs_type = input_rhs->getType().cast<RankedTensorType>();
    rhs_shape = rhs_type.getShape();
  }

  // Ensure that input ranks are at least 2 and batch shapes are
  // broadcastable.
  const int dims_a = lhs_shape.size();
  const int dims_b = rhs_shape.size();
  if (dims_a < 2 || dims_b < 2) {
    // Both inputs must have rank >= 2
    return this->matchFailure();
  }

  if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
    // Input dimensions must be compatible for multiplication.
    return this->matchFailure();
  }

  if (dims_a == 2 && dims_b == 2) {
    // When both inputs are matrices, just replace the op to a matmul op.
    Type resultType =
        rewriter.getTensorType({lhs_shape[0], rhs_shape[1]}, elementType);
    rewriter.replaceOpWithNewOp<TF::MatMulOp>(
        op, resultType,
        /* a = */ input_lhs,
        /* b = */ input_rhs,
        /* transpose_a = */ rewriter.getBoolAttr(false),
        /* transpose_b = */ rewriter.getBoolAttr(false));
    return this->matchSuccess();
  }

  // Compute the broadcast mapping between the two batch shapes (everything
  // except the trailing two matrix dimensions).
  tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
                                    lhs_shape.begin(), lhs_shape.end()),
                                absl::InlinedVector<tensorflow::int64, 4>(
                                    rhs_shape.begin(), rhs_shape.end()));

  if (!bcast.IsValid()) {
    // Input batch dimensions must be broadcastable
    return this->matchFailure();
  }

  // Compute slices for each batch in the LHS and RHS.
  std::vector<Value*> sliced_lhs =
      sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
  std::vector<Value*> sliced_rhs =
      sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);

  // Compute (single batch) MatMul for each output batch. The MatMul outputs
  // are then packed together into one output Tensor.
  auto packOp =
      createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2],
                      rhs_shape[dims_b - 1], elementType, loc, rewriter);

  // Reshape the rank-3 Tensor into the correct output shape:
  // broadcast batch dims followed by [rows of LHS, cols of RHS].
  const auto& resultBatchShape = bcast.output_batch_shape().dim_sizes();
  std::vector<int64_t> resultShape(resultBatchShape.begin(),
                                   resultBatchShape.end());
  resultShape.push_back(lhs_shape[dims_a - 2]);
  resultShape.push_back(rhs_shape[dims_b - 1]);

  auto reshapeOp =
      createReshapeOp(packOp.output(), resultShape, elementType, loc, rewriter);
  rewriter.replaceOp(op, reshapeOp.output());
  return this->matchSuccess();
}
|
||||
|
||||
// Registers the pass with the pass registry under the flag
// -tfl-unroll-batch-matmul (as exercised by the unroll-batch-matmul.mlir
// FileCheck test).
static PassRegistration<UnrollBatchMatMulPass> pass(
    "tfl-unroll-batch-matmul",
    "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
@ -0,0 +1,60 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/core/util/matmul_bcast.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
|
||||
// Unroll tf.BatchMatMulV2 op into a sequence of TF ops. Since TFLite does not
// support BatchMatMul operation, it unrolls a BatchMatMul op into tf.Reshape,
// tf.Slice, tf.MatMul, tf.Pack, and tf.Reshape ops.
//
// Instantiated for both tf.BatchMatMul and tf.BatchMatMulV2 via the
// BatchMatMulOpType template parameter.
template <typename BatchMatMulOpType>
class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> {
  using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern;

  // Emits a tf.Reshape of `value` to `shape` with element type `elementType`.
  static TF::ReshapeOp createReshapeOp(Value* value, ArrayRef<int64_t> shape,
                                       Type elementType, Location loc,
                                       PatternRewriter& rewriter);

  // Splits `value` into `batch_size` slices, one per batch element
  // (presumably each squeezed to rank 2 for use by tf.MatMul — body not
  // fully visible here, confirm in the .cc file).
  static std::vector<Value*> sliceInput(Value* value, int batch_size,
                                        Location loc,
                                        PatternRewriter& rewriter);

  // Emits a tf.Transpose that swaps the two innermost dimensions of `value`,
  // keeping all leading batch dimensions in place.
  static TF::TransposeOp createTransposeOp(Value* value, Location loc,
                                           PatternRewriter& rewriter);

  // Emits one [rows, cols] tf.MatMul per output batch element (indices
  // mapped through `bcast` when broadcasting is required) and packs the
  // results into a rank-3 tensor via tf.Pack along axis 0.
  static TF::PackOp createMatMulOps(const std::vector<Value*>& sliced_lhs,
                                    const std::vector<Value*>& sliced_rhs,
                                    const tensorflow::MatMulBCast& bcast,
                                    int rows, int cols, Type elementType,
                                    Location loc, PatternRewriter& rewriter);

  // Performs the full rewrite; fails the match for unranked operands,
  // mismatched element types, rank < 2, incompatible inner dimensions, or
  // non-broadcastable batch shapes.
  PatternMatchResult matchAndRewrite(BatchMatMulOpType op,
                                     PatternRewriter& rewriter) const override;
};
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNROLL_BATCH_MATMUL_H_
|
@ -261,6 +261,88 @@ window in `value`.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
// tf.BatchMatMul: batched matrix multiply over rank >= 2 operands with
// optional adjoint of either operand (no batch-dimension broadcasting;
// compare TF_BatchMatMulV2Op below).
def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect]> {
  let summary = "Multiplies slices of two tensors in batches.";

  let description = [{
Multiplies all slices of `Tensor` `x` and `y` (each slice can be
viewed as an element of a batch), and arranges the individual results
in a single output tensor of the same batch size. Each of the
individual slices can optionally be adjointed (to adjoint a matrix
means to transpose and conjugate it) before multiplication by setting
the `adj_x` or `adj_y` flag to `True`, which are by default `False`.

The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
and `[..., r_y, c_y]`.

The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:

    r_o = c_x if adj_x else r_x
    c_o = r_y if adj_y else c_y

It is computed as:

    output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,

    DefaultValuedAttr<BoolAttr, "false">:$adj_x,
    DefaultValuedAttr<BoolAttr, "false">:$adj_y
  );

  let results = (outs
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
  );

  // Element type attribute T is derived from operand 0 (`x`).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
||||
|
||||
// tf.BatchMatMulV2: like tf.BatchMatMul but additionally supports NumPy-style
// broadcasting of the batch dimensions.
def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect]> {
  let summary = "Multiplies slices of two tensors in batches.";

  let description = [{
Multiplies all slices of `Tensor` `x` and `y` (each slice can be
viewed as an element of a batch), and arranges the individual results
in a single output tensor of the same batch size. Each of the
individual slices can optionally be adjointed (to adjoint a matrix
means to transpose and conjugate it) before multiplication by setting
the `adj_x` or `adj_y` flag to `True`, which are by default `False`.

The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
and `[..., r_y, c_y]`.

The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where:

    r_o = c_x if adj_x else r_x
    c_o = r_y if adj_y else c_y

It is computed as:

    output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])

*NOTE*: `BatchMatMulV2` supports broadcasting in the batch dimensions. More
about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x,
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y,

    DefaultValuedAttr<BoolAttr, "false">:$adj_x,
    DefaultValuedAttr<BoolAttr, "false">:$adj_y
  );

  let results = (outs
    TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output
  );

  // Element type attribute T is derived from operand 0 (`x`).
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
|
||||
|
||||
def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
|
||||
let summary = "BatchToSpace for N-D tensors of type T.";
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user