diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index f30bd961fca..d8a1a156b0c 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -3755,3 +3755,40 @@ func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor)
   return %0 : tensor
 }
 
+// CHECK-LABEL: func @batchmatmulv2_adj_real
+func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
+  // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xf32>) -> tensor<5x2xf32>
+  // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<2x4xf32>
+  // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
+  // CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+  // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
+  // CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+  // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+  // CHECK-SAME: }} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
+  // CHECK: return [[BDST]] : tensor<5x4xf32>
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
+  return %0 : tensor<5x4xf32>
+}
+
+// CHECK-LABEL: func @batchmatmulv2_adj_complex
+func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
+  // CHECK: [[LHSRE:%.+]] = "xla_hlo.real"(%arg0) : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xf32>
+  // CHECK: [[LHSIM:%.+]] = "xla_hlo.imag"(%arg0) : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xf32>
+  // CHECK: [[LHSIMNEG:%.+]] = "xla_hlo.neg"([[LHSIM]]) : (tensor<5x2xf32>) -> tensor<5x2xf32>
+  // CHECK: [[LHSCONJ:%.+]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) : (tensor<5x2xf32>, tensor<5x2xf32>) -> tensor<5x2xcomplex<f32>>
+  // CHECK: [[RHSRE:%.+]] = "xla_hlo.real"(%arg1) : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xf32>
+  // CHECK: [[RHSIM:%.+]] = "xla_hlo.imag"(%arg1) : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xf32>
+  // CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.neg"([[RHSIM]]) : (tensor<2x4xf32>) -> tensor<2x4xf32>
+  // CHECK: [[RHSCONJ:%.+]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xcomplex<f32>>
+  // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"([[LHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xcomplex<f32>>
+  // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"([[RHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xcomplex<f32>>
+  // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
+  // CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+  // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
+  // CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+  // CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+  // CHECK-SAME: }} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
+  // CHECK: return [[BDST]] : tensor<5x4xcomplex<f32>>
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
+  return %0 : tensor<5x4xcomplex<f32>>
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 817dfb55ec9..65704ca8dec 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -1629,17 +1629,17 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
   LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
                                 PatternRewriter &rewriter) const override {
-    // TODO(silvasean): Handle adj_x/adj_y
-    // Should be able to just set the contracting_dimensions attribute
-    // appropriately.
-    // For complex types, need to do a complex conjugation.
-    if (op.adj_x() || op.adj_y()) return failure();
-
     Value lhs = op.x();
     Value rhs = op.y();
     auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
     auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
     if (!lhs_type || !rhs_type) return failure();
+    if (lhs_type.getElementType().isa<ComplexType>() && op.adj_x()) {
+      lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
+    }
+    if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
+      rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
+    }
     // TODO(silvasean): Support dynamic shapes.
     if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) {
       return failure();
     }
@@ -1654,10 +1654,10 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
     int64_t rank = lhs_type.getRank();
     auto batch_dimensions = GetI64ElementsAttr(
         llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2)), &rewriter);
-    auto lhs_contracting_dimensions =
-        GetI64ElementsAttr(llvm::makeArrayRef({rank - 1}), &rewriter);
-    auto rhs_contracting_dimensions =
-        GetI64ElementsAttr(llvm::makeArrayRef({rank - 2}), &rewriter);
+    auto lhs_contracting_dimensions = GetI64ElementsAttr(
+        llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}), &rewriter);
+    auto rhs_contracting_dimensions = GetI64ElementsAttr(
+        llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}), &rewriter);
     auto dimension_numbers = DotDimensionNumbers::get(
         /*lhs_batching_dimensions=*/batch_dimensions,
         /*rhs_batching_dimensions=*/batch_dimensions,