MLIR TF2XLA bridge: BatchMatMulV2 lowering for dynamic shapes

PiperOrigin-RevId: 308722335
Change-Id: I070f6eafbcb49faa6183817f675c6674199bdad5
Author: Sean Silva, 2020-04-27 16:31:50 -07:00 (committed by TensorFlower Gardener)
Parent: b1c4cf2183
Commit: 7cd479b8fc
4 changed files with 147 additions and 112 deletions

@@ -146,6 +146,7 @@ cc_library(
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",

@@ -0,0 +1,93 @@
// RUN: tf-opt -xla-legalize-tf=allow-partial-conversion %s | FileCheck %s --dump-input-on-failure
//===----------------------------------------------------------------------===//
// tf.BatchMatMulV2 op legalizations.
//===----------------------------------------------------------------------===//
func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_basic
// CHECK-SAME: ([[LHS:%.*]]: tensor<1x4x2xf32>, [[RHS:%.*]]: tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: [[LHSSHAPE:%.*]] = "shape.shape_of"([[LHS]]) : (tensor<1x4x2xf32>) -> !shape.shape
// CHECK: [[RHSSHAPE:%.*]] = "shape.shape_of"([[RHS]]) : (tensor<3x2x4xf32>) -> !shape.shape
// CHECK: [[CM2:%.*]] = constant -2 : i32
// CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
// CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[LHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex>
// CHECK: [[LHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
// CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[RHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[RHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex>
// CHECK: [[RHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
// CHECK: [[RESULT:%.*]] = "xla_hlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: return [[RESULT]] : tensor<3x4x4xf32>
// CHECK: }
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
return %0 : tensor<3x4x4xf32>
}
func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_lhs_batch
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}}
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<3x4x2xf32>, tensor<2x4xf32>) -> tensor<3x4x4xf32>
return %0 : tensor<3x4x4xf32>
}
func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_rhs_batch
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
// CHECK: "xla_hlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}}
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
return %0 : tensor<3x4x4xf32>
}
func @batchmatmulv2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK-LABEL: func @batchmatmulv2_dynamic
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}}
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_adj_real
// CHECK: "xla_hlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}}
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
return %0 : tensor<5x4xf32>
}
func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
// CHECK-LABEL: func @batchmatmulv2_adj_complex(
// CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex<f32>>, [[RHS:%.*]]: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
// CHECK: [[LHSRE:%.*]] = "xla_hlo.real"([[LHS]])
// CHECK: [[LHSIM:%.*]] = "xla_hlo.imag"([[LHS]])
// CHECK: [[LHSIMNEG:%.*]] = "xla_hlo.negate"([[LHSIM]])
// CHECK: [[LHSCONJ:%.*]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]])
// CHECK: [[RHSRE:%.*]] = "xla_hlo.real"([[RHS]])
// CHECK: [[RHSIM:%.*]] = "xla_hlo.imag"([[RHS]])
// CHECK: [[RHSIMNEG:%.*]] = "xla_hlo.negate"([[RHSIM]])
// CHECK: [[RHSCONJ:%.*]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]])
// CHECK: "shape.shape_of"([[LHSCONJ]])
// CHECK: "shape.shape_of"([[RHSCONJ]])
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
return %0 : tensor<5x4xcomplex<f32>>
}

@@ -4117,100 +4117,6 @@ func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
//===----------------------------------------------------------------------===//
// tf.BatchMatMulV2 op legalizations.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @batchmatmulv2_broadcast_singleton_dimension
func @batchmatmulv2_broadcast_singleton_dimension(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, {{.*}}) -> tensor<3x4x2xf32>
// CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, {{.*}}) -> tensor<3x2x4xf32>
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
// CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: return [[BDST]] : tensor<3x4x4xf32>
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<1x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
return %0 : tensor<3x4x4xf32>
}
// CHECK-LABEL: func @batchmatmulv2_lhs_batch
func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x4x2xf32>, {{.*}}) -> tensor<3x4x2xf32>
// CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x4xf32>, {{.*}}) -> tensor<3x2x4xf32>
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
// CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: return [[BDST]] : tensor<3x4x4xf32>
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<3x4x2xf32>, tensor<2x4xf32>) -> tensor<3x4x4xf32>
return %0 : tensor<3x4x4xf32>
}
// CHECK-LABEL: func @batchmatmulv2_rhs_batch
func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> {
// CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x2xf32>, {{.*}}) -> tensor<3x4x2xf32>
// CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, {{.*}}) -> tensor<3x2x4xf32>
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
// CHECK-SAME: }} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: return [[BDST]] : tensor<3x4x4xf32>
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
return %0 : tensor<3x4x4xf32>
}
// CHECK-LABEL: func @batchmatmulv2_dynamic
func @batchmatmulv2_dynamic(%arg0: tensor<?x4x2xf32>, %arg1: tensor<?x2x4xf32>) -> tensor<?x4x4xf32> {
// CHECK: "tf.BatchMatMulV2"
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {T = f32, adj_x = false, adj_y = false, device = ""} : (tensor<?x4x2xf32>, tensor<?x2x4xf32>) -> tensor<?x4x4xf32>
return %0 : tensor<?x4x4xf32>
}
// CHECK-LABEL: func @batchmatmulv2_adj_real
func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
// CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xf32>, {{.*}}) -> tensor<5x2xf32>
// CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xf32>, {{.*}}) -> tensor<2x4xf32>
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
// CHECK-SAME: }} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
// CHECK: return [[BDST]] : tensor<5x4xf32>
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
return %0 : tensor<5x4xf32>
}
// CHECK-LABEL: func @batchmatmulv2_adj_complex
func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
// CHECK: [[LHSRE:%.+]] = "xla_hlo.real"(%arg0) : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xf32>
// CHECK: [[LHSIM:%.+]] = "xla_hlo.imag"(%arg0) : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xf32>
// CHECK: [[LHSIMNEG:%.+]] = "xla_hlo.negate"([[LHSIM]]) : (tensor<5x2xf32>) -> tensor<5x2xf32>
// CHECK: [[LHSCONJ:%.+]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) : (tensor<5x2xf32>, tensor<5x2xf32>) -> tensor<5x2xcomplex<f32>>
// CHECK: [[RHSRE:%.+]] = "xla_hlo.real"(%arg1) : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xf32>
// CHECK: [[RHSIM:%.+]] = "xla_hlo.imag"(%arg1) : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xf32>
// CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.negate"([[RHSIM]]) : (tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: [[RHSCONJ:%.+]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xcomplex<f32>>
// CHECK: [[BLHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHSCONJ]], {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xcomplex<f32>>, {{.*}}) -> tensor<5x2xcomplex<f32>>
// CHECK: [[BRHS:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHSCONJ]], {{.*}}) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xcomplex<f32>>, {{.*}}) -> tensor<2x4xcomplex<f32>>
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
// CHECK-SAME: }} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
// CHECK: return [[BDST]] : tensor<5x4xcomplex<f32>>
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
return %0 : tensor<5x4xcomplex<f32>>
}
// CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) {
// CHECK: [[VAL_1:%.*]] = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<100x100xi32>

@@ -25,6 +25,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
@@ -1869,29 +1870,63 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc,
Value *out_lhs, Value *out_rhs,
PatternRewriter *rewriter) {
// The dimension structure of the relevant operands to a tf.BatchMatMulV2 is:
// - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS]
// - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS]
// - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS]
// To perform the matmul, we need to first broadcast lhs and rhs to a common
// set of leading dimensions before doing the actual matmul.
// That's what the code below does.
// In particular, we populate out_lhs and out_rhs to have dimension structure:
// - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS]
// - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS]
// To do this, we need to calculate those output shapes, which involves
// slicing off the leading batch dims of each operand, broadcasting them,
// then concatenating the broadcasted leading dims back to the row/col dims.
// Finally, we create a TF::BroadcastTo op that does the actual broadcast.
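// For example, using the shapes from the @batchmatmulv2_basic test case in
// this change:
//   lhs: tensor<1x4x2xf32>   (LHSBATCHDIMS = [1], LHSROWS = 4, LHSCOLS = 2)
//   rhs: tensor<3x2x4xf32>   (RHSBATCHDIMS = [3], RHSROWS = 2, RHSCOLS = 4)
//   broadcast([1], [3]) = [3]
//   out_lhs: tensor<3x4x2xf32>, out_rhs: tensor<3x2x4xf32>
//   matmul result: tensor<3x4x4xf32>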
// TODO(silvasean): Reduce duplication across reified shape calculations and
// the static computation of output types needed to create ops.
Value lhs_shape = rewriter->create<shape::ShapeOfOp>(loc, lhs);
Value rhs_shape = rewriter->create<shape::ShapeOfOp>(loc, rhs);
Value const_neg2 =
rewriter->create<ConstantOp>(loc, rewriter->getI32IntegerAttr(-2));
auto lhs_splitted =
rewriter->create<shape::SplitAtOp>(loc, lhs_shape, const_neg2);
auto rhs_splitted =
rewriter->create<shape::SplitAtOp>(loc, rhs_shape, const_neg2);
auto lhs_type = lhs.getType().cast<RankedTensorType>();
auto rhs_type = rhs.getType().cast<RankedTensorType>();
// The last two dimensions are the matrix row/col dimensions. Don't
// broadcast them.
SmallVector<int64_t, 6> result_batch_shape;
// The last two dimensions are the matrix row/col dimensions. Don't broadcast
// them.
SmallVector<int64_t, 6> result_batch_shape_compile_time_extents;
OpTrait::util::getBroadcastedShape(lhs_type.getShape().drop_back(2),
rhs_type.getShape().drop_back(2),
result_batch_shape);
auto handle_one_side = [rewriter, &result_batch_shape, loc](
Value side, RankedTensorType type,
Value *out_side) {
result_batch_shape_compile_time_extents);
auto result_batch_shape = rewriter->create<shape::BroadcastOp>(
loc, lhs_splitted.head(), rhs_splitted.head(),
/*error=*/nullptr);
// Lambda which handles the broadcasting of one side to the common
// leading-batch dimensions.
auto broadcast_one_side = [&](Value side, RankedTensorType type,
Value tail_shape, Value *out_side) {
ArrayRef<int64_t> matrix_dims = type.getShape().take_back(2);
auto result_shape = result_batch_shape;
auto result_shape = result_batch_shape_compile_time_extents;
result_shape.append(matrix_dims.begin(), matrix_dims.end());
auto result_type =
RankedTensorType::get(result_shape, type.getElementType());
auto shape = rewriter->create<TF::ConstOp>(
loc, GetI64ElementsAttr(result_shape, rewriter));
*out_side =
rewriter->create<TF::BroadcastToOp>(loc, result_type, side, shape);
auto shape =
rewriter->create<shape::ConcatOp>(loc, result_batch_shape, tail_shape);
auto shape_tensor = rewriter->create<shape::ToExtentTensorOp>(
loc,
RankedTensorType::get({static_cast<int64_t>(result_shape.size())},
rewriter->getIndexType()),
shape);
*out_side = rewriter->create<TF::BroadcastToOp>(loc, result_type, side,
shape_tensor);
};
handle_one_side(lhs, lhs_type, out_lhs);
handle_one_side(rhs, rhs_type, out_rhs);
broadcast_one_side(lhs, lhs_type, lhs_splitted.tail(), out_lhs);
broadcast_one_side(rhs, rhs_type, rhs_splitted.tail(), out_rhs);
}
class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
@@ -1911,10 +1946,6 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
}
// TODO(silvasean): Support dynamic shapes.
if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) {
return failure();
}
// Broadcast both operands.
BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs,
@@ -1935,6 +1966,8 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
/*lhs_contracting_dimensions=*/lhs_contracting_dimensions,
/*rhs_contracting_dimensions=*/rhs_contracting_dimensions,
rewriter.getContext());
// TODO(silvasean): Emit shape checks for contracting dimensions.
// (The batch dimensions are checked by the broadcasting logic)
rewriter.replaceOpWithNewOp<DotGeneralOp>(op, op.getType(), lhs, rhs,
dimension_numbers,
/*precision_config=*/nullptr);
@@ -4753,6 +4786,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
ConversionTarget target(*context);
target.addLegalDialect<XlaHloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<shape::ShapeDialect>();
target.addLegalOp<CallOp>();
target.addLegalOp<TensorCastOp>();