MLIR TF2XLA bridge: BatchMatMulV2 lowering for dynamic shapes
PiperOrigin-RevId: 308722335 Change-Id: I070f6eafbcb49faa6183817f675c6674199bdad5
This commit is contained in:
parent
b1c4cf2183
commit
7cd479b8fc
@ -146,6 +146,7 @@ cc_library(
|
||||
"@llvm-project//mlir:Dialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Shape",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
|
@ -0,0 +1,93 @@
|
||||
// RUN: tf-opt -xla-legalize-tf=allow-partial-conversion %s | FileCheck %s --dump-input-on-failure
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf.BatchMatMulV2 op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
|
||||
// CHECK-LABEL: func @batchmatmulv2_basic
|
||||
// CHECK-SAME: ([[LHS:%.*]]: tensor<1x4x2xf32>, [[RHS:%.*]]: tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
|
||||
// CHECK: [[LHSSHAPE:%.*]] = "shape.shape_of"([[LHS]]) : (tensor<1x4x2xf32>) -> !shape.shape
|
||||
// CHECK: [[RHSSHAPE:%.*]] = "shape.shape_of"([[RHS]]) : (tensor<3x2x4xf32>) -> !shape.shape
|
||||
// CHECK: [[CM2:%.*]] = constant -2 : i32
|
||||
// CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
|
||||
// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
|
||||
// CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
// CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
// CHECK: [[LHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[LHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex>
|
||||
// CHECK: [[LHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
|
||||
// CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
// CHECK: [[RHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[RHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex>
|
||||
// CHECK: [[RHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
|
||||
// CHECK: [[RESULT:%.*]] = "xla_hlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
|
||||
// CHECK: return [[RESULT]] : tensor<3x4x4xf32>
|
||||
// CHECK: }
|
||||
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
|
||||
return %0 : tensor<3x4x4xf32>
|
||||
}
|
||||
|
||||
func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> {
|
||||
// CHECK-LABEL: func @batchmatmulv2_lhs_batch
|
||||
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
|
||||
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
|
||||
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}}
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<3x4x2xf32>, tensor<2x4xf32>) -> tensor<3x4x4xf32>
|
||||
return %0 : tensor<3x4x4xf32>
|
||||
}
|
||||
|
||||
func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
|
||||
// CHECK-LABEL: func @batchmatmulv2_rhs_batch
|
||||
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
|
||||
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
|
||||
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}}
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
|
||||
return %0 : tensor<3x4x4xf32>
|
||||
}
|
||||
|
||||
func @batchmatmulv2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
|
||||
// CHECK-LABEL: func @batchmatmulv2_dynamic
|
||||
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}}
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
|
||||
// CHECK-LABEL: func @batchmatmulv2_adj_real
|
||||
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}}
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
|
||||
return %0 : tensor<5x4xf32>
|
||||
}
|
||||
|
||||
func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
|
||||
// CHECK-LABEL: func @batchmatmulv2_adj_complex(
|
||||
// CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex<f32>>, [[RHS:%.*]]: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
|
||||
// CHECK: [[LHSRE:%.*]] = "xla_hlo.real"([[LHS]])
|
||||
// CHECK: [[LHSIM:%.*]] = "xla_hlo.imag"([[LHS]])
|
||||
// CHECK: [[LHSIMNEG:%.*]] = "xla_hlo.negate"([[LHSIM]])
|
||||
// CHECK: [[LHSCONJ:%.*]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]])
|
||||
// CHECK: [[RHSRE:%.*]] = "xla_hlo.real"([[RHS]])
|
||||
// CHECK: [[RHSIM:%.*]] = "xla_hlo.imag"([[RHS]])
|
||||
// CHECK: [[RHSIMNEG:%.*]] = "xla_hlo.negate"([[RHSIM]])
|
||||
// CHECK: [[RHSCONJ:%.*]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]])
|
||||
// CHECK: "shape.shape_of"([[LHSCONJ]])
|
||||
// CHECK: "shape.shape_of"([[RHSCONJ]])
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
|
||||
return %0 : tensor<5x4xcomplex<f32>>
|
||||
}
|
@ -4117,100 +4117,6 @@ func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tf.BatchMatMulV2 op legalizations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @batchmatmulv2_broadcast_singleton_dimension
|
||||
func @batchmatmulv2_broadcast_singleton_dimension(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
|
||||
// CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, {{.*}}) -> tensor<3x4x2xf32>
|
||||
// CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, {{.*}}) -> tensor<3x2x4xf32>
|
||||
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
|
||||
// CHECK: return [[BDST]] : tensor<3x4x4xf32>
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
|
||||
return %0 : tensor<3x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @batchmatmulv2_lhs_batch
|
||||
func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> {
|
||||
// CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x4x2xf32>, {{.*}}) -> tensor<3x4x2xf32>
|
||||
// CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>, {{.*}}) -> tensor<3x2x4xf32>
|
||||
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
|
||||
// CHECK: return [[BDST]] : tensor<3x4x4xf32>
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<3x4x2xf32>, tensor<2x4xf32>) -> tensor<3x4x4xf32>
|
||||
return %0 : tensor<3x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @batchmatmulv2_rhs_batch
|
||||
func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
|
||||
// CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xf32>, {{.*}}) -> tensor<3x4x2xf32>
|
||||
// CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, {{.*}}) -> tensor<3x2x4xf32>
|
||||
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
|
||||
// CHECK: return [[BDST]] : tensor<3x4x4xf32>
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
|
||||
return %0 : tensor<3x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @batchmatmulv2_dynamic
|
||||
func @batchmatmulv2_dynamic(%arg0: tensor<?x4x2xf32>, %arg1: tensor<?x2x4xf32>) -> tensor<?x4x4xf32> {
|
||||
// CHECK: "tf.BatchMatMulV2"
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<?x4x2xf32>, tensor<?x2x4xf32>) -> tensor<?x4x4xf32>
|
||||
return %0 : tensor<?x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @batchmatmulv2_adj_real
|
||||
func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
|
||||
// CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xf32>, {{.*}}) -> tensor<5x2xf32>
|
||||
// CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xf32>, {{.*}}) -> tensor<2x4xf32>
|
||||
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: }} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
|
||||
// CHECK: return [[BDST]] : tensor<5x4xf32>
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
|
||||
return %0 : tensor<5x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @batchmatmulv2_adj_complex
|
||||
func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
|
||||
// CHECK: [[LHSRE:%.+]] = "xla_hlo.real"(%arg0) : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xf32>
|
||||
// CHECK: [[LHSIM:%.+]] = "xla_hlo.imag"(%arg0) : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xf32>
|
||||
// CHECK: [[LHSIMNEG:%.+]] = "xla_hlo.negate"([[LHSIM]]) : (tensor<5x2xf32>) -> tensor<5x2xf32>
|
||||
// CHECK: [[LHSCONJ:%.+]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) : (tensor<5x2xf32>, tensor<5x2xf32>) -> tensor<5x2xcomplex<f32>>
|
||||
// CHECK: [[RHSRE:%.+]] = "xla_hlo.real"(%arg1) : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xf32>
|
||||
// CHECK: [[RHSIM:%.+]] = "xla_hlo.imag"(%arg1) : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xf32>
|
||||
// CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.negate"([[RHSIM]]) : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: [[RHSCONJ:%.+]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xcomplex<f32>>
|
||||
// CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHSCONJ]], {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xcomplex<f32>>, {{.*}}) -> tensor<5x2xcomplex<f32>>
|
||||
// CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHSCONJ]], {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xcomplex<f32>>, {{.*}}) -> tensor<2x4xcomplex<f32>>
|
||||
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: }} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
|
||||
// CHECK: return [[BDST]] : tensor<5x4xcomplex<f32>>
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
|
||||
return %0 : tensor<5x4xcomplex<f32>>
|
||||
}
|
||||
|
||||
// CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
|
||||
func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) {
|
||||
// CHECK: [[VAL_1:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100x100xi32>
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Traits.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
@ -1869,29 +1870,63 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
|
||||
static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
|
||||
Value *out_lhs, Value *out_rhs,
|
||||
PatternRewriter *rewriter) {
|
||||
// The dimension structure of the relevant operands to a tf.BatchMatMulV2 is:
|
||||
// - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS]
|
||||
// - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS]
|
||||
// - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS]
|
||||
// To perform the matmul, we need to first broadcast lhs and rhs to a common
|
||||
// set of leading dimensions before doing the actual matmul.
|
||||
// That's what the code below does.
|
||||
// In particular, we populate out_lhs and out_rhs to have dimension structure:
|
||||
// - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS]
|
||||
// - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS]
|
||||
// To do this, we need to calculate those output shapes, which involves
|
||||
// slicing off the leading batch dims of each operand, broadcasting them,
|
||||
// then concatenating the broadcasted leading dims back to the row/col dims.
|
||||
// Finally, we create a TF::BroadcastTo op that does the actual broadcast.
|
||||
|
||||
// TODO(silvasean): Reduce duplication across reified shape calculations and
|
||||
// the static computation of output types needed to create ops.
|
||||
Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
|
||||
Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
|
||||
Value const_neg2 =
|
||||
rewriter->create<ConstantOp>(loc, rewriter->getI32IntegerAttr(-2));
|
||||
auto lhs_splitted =
|
||||
rewriter->create<shape::SplitAtOp>(loc, lhs_shape, const_neg2);
|
||||
auto rhs_splitted =
|
||||
rewriter->create<shape::SplitAtOp>(loc, rhs_shape, const_neg2);
|
||||
auto lhs_type = lhs.getType().cast<RankedTensorType>();
|
||||
auto rhs_type = rhs.getType().cast<RankedTensorType>();
|
||||
// The last two dimensions are the matrix row/col dimensions. Don't
|
||||
// broadcast them.
|
||||
SmallVector<int64_t, 6> result_batch_shape;
|
||||
// The last two dimensions are the matrix row/col dimensions. Don't broadcast
|
||||
// them.
|
||||
SmallVector<int64_t, 6> result_batch_shape_compile_time_extents;
|
||||
OpTrait::util::getBroadcastedShape(lhs_type.getShape().drop_back(2),
|
||||
rhs_type.getShape().drop_back(2),
|
||||
result_batch_shape);
|
||||
auto handle_one_side = [rewriter, &result_batch_shape, loc](
|
||||
Value side, RankedTensorType type,
|
||||
Value *out_side) {
|
||||
result_batch_shape_compile_time_extents);
|
||||
auto result_batch_shape = rewriter->create<shape::BroadcastOp>(
|
||||
loc, lhs_splitted.head(), rhs_splitted.head(),
|
||||
/*error=*/nullptr);
|
||||
// Lambda which handles the broadcasting of one side to the common
|
||||
// leading-batch dimensions.
|
||||
auto broadcast_one_side = [&](Value side, RankedTensorType type,
|
||||
Value tail_shape, Value *out_side) {
|
||||
ArrayRef<int64_t> matrix_dims = type.getShape().take_back(2);
|
||||
auto result_shape = result_batch_shape;
|
||||
auto result_shape = result_batch_shape_compile_time_extents;
|
||||
result_shape.append(matrix_dims.begin(), matrix_dims.end());
|
||||
auto result_type =
|
||||
RankedTensorType::get(result_shape, type.getElementType());
|
||||
auto shape = rewriter->create<TF::ConstOp>(
|
||||
loc, GetI64ElementsAttr(result_shape, rewriter));
|
||||
*out_side =
|
||||
rewriter->create<TF::BroadcastToOp>(loc, result_type, side, shape);
|
||||
auto shape =
|
||||
rewriter->create<shape::ConcatOp>(loc, result_batch_shape, tail_shape);
|
||||
auto shape_tensor = rewriter->create<shape::ToExtentTensorOp>(
|
||||
loc,
|
||||
RankedTensorType::get({static_cast<int64_t>(result_shape.size())},
|
||||
rewriter->getIndexType()),
|
||||
shape);
|
||||
*out_side = rewriter->create<TF::BroadcastToOp>(loc, result_type, side,
|
||||
shape_tensor);
|
||||
};
|
||||
handle_one_side(lhs, lhs_type, out_lhs);
|
||||
handle_one_side(rhs, rhs_type, out_rhs);
|
||||
broadcast_one_side(lhs, lhs_type, lhs_splitted.tail(), out_lhs);
|
||||
broadcast_one_side(rhs, rhs_type, rhs_splitted.tail(), out_rhs);
|
||||
}
|
||||
|
||||
class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
|
||||
@ -1911,10 +1946,6 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
|
||||
if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
|
||||
rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
|
||||
}
|
||||
// TODO(silvasean): Support dynamic shapes.
|
||||
if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Broadcast both operands.
|
||||
BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs,
|
||||
@ -1935,6 +1966,8 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
|
||||
/*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
|
||||
/*rhs_contracting_dimensions=*/rhs_contracting_dimensions,
|
||||
rewriter.getContext());
|
||||
// TODO(silvasean): Emit shape checks for contracting dimensions.
|
||||
// (The batch dimensions are checked by the broadcasting logic)
|
||||
rewriter.replaceOpWithNewOp<DotGeneralOp>(op, op.getType(), lhs, rhs,
|
||||
dimension_numbers,
|
||||
/*precision_config=*/nullptr);
|
||||
@ -4753,6 +4786,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
|
||||
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<XlaHloDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalDialect<shape::ShapeDialect>();
|
||||
target.addLegalOp<CallOp>();
|
||||
target.addLegalOp<TensorCastOp>();
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user