TF2XLA: BatchMatMulV2: add adj_x/adj_y support

PiperOrigin-RevId: 302565809
Change-Id: Ib325e819e7ce913bc59deed6aedc5a29e0a28344
Author:    Sean Silva
Date:      2020-03-23 18:36:05 -07:00
Committer: TensorFlower Gardener
Commit:    60d6ea479e (parent 5d93f28897)
2 changed files with 47 additions and 10 deletions
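
For context: with adj_x/adj_y set, BatchMatMulV2 multiplies the conjugate
transposes of its operands' last two dimensions. A minimal reference sketch of
the unbatched 2-D case (illustration only, not code from this commit; the
AdjMatMul helper and its row-major layout are assumptions):

    #include <complex>
    #include <vector>

    // Computes conj(x)^T * conj(y)^T, i.e. BatchMatMulV2 with
    // adj_x = adj_y = true and no batch dimensions.
    // x is k x m, y is n x k, both row-major; the result is m x n.
    std::vector<std::complex<float>> AdjMatMul(
        const std::vector<std::complex<float>> &x,
        const std::vector<std::complex<float>> &y, int k, int m, int n) {
      std::vector<std::complex<float>> out(m * n, {0.0f, 0.0f});
      for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
          for (int p = 0; p < k; ++p)
            // conj(x)^T(i, p) = conj(x(p, i)); conj(y)^T(p, j) = conj(y(j, p)).
            out[i * n + j] += std::conj(x[p * m + i]) * std::conj(y[j * k + p]);
      return out;
    }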

@@ -3755,3 +3755,40 @@ func @batchmatmulv2_dynamic(%arg0: tensor<?x4x2xf32>, %arg1: tensor<?x2x4xf32>)
   return %0 : tensor<?x4x4xf32>
 }
 
+// CHECK-LABEL: func @batchmatmulv2_adj_real
+func @batchmatmulv2_adj_real(%arg0: tensor<2x5xf32>, %arg1: tensor<4x2xf32>) -> tensor<5x4xf32> {
+  // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x5xf32>) -> tensor<2x5xf32>
+  // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x2xf32>) -> tensor<4x2xf32>
+  // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
+  // CHECK-SAME:   lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+  // CHECK-SAME:   lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
+  // CHECK-SAME:   rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+  // CHECK-SAME:   rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+  // CHECK-SAME: }} : (tensor<2x5xf32>, tensor<4x2xf32>) -> tensor<5x4xf32>
+  // CHECK: return [[BDST]] : tensor<5x4xf32>
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xf32>, tensor<4x2xf32>) -> tensor<5x4xf32>
+  return %0 : tensor<5x4xf32>
+}
+
+// CHECK-LABEL: func @batchmatmulv2_adj_complex
+func @batchmatmulv2_adj_complex(%arg0: tensor<2x5xcomplex<f32>>, %arg1: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
+  // CHECK: [[LHSRE:%.+]] = "xla_hlo.real"(%arg0) : (tensor<2x5xcomplex<f32>>) -> tensor<2x5xf32>
+  // CHECK: [[LHSIM:%.+]] = "xla_hlo.imag"(%arg0) : (tensor<2x5xcomplex<f32>>) -> tensor<2x5xf32>
+  // CHECK: [[LHSIMNEG:%.+]] = "xla_hlo.neg"([[LHSIM]]) : (tensor<2x5xf32>) -> tensor<2x5xf32>
+  // CHECK: [[LHSCONJ:%.+]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) : (tensor<2x5xf32>, tensor<2x5xf32>) -> tensor<2x5xcomplex<f32>>
+  // CHECK: [[RHSRE:%.+]] = "xla_hlo.real"(%arg1) : (tensor<4x2xcomplex<f32>>) -> tensor<4x2xf32>
+  // CHECK: [[RHSIM:%.+]] = "xla_hlo.imag"(%arg1) : (tensor<4x2xcomplex<f32>>) -> tensor<4x2xf32>
+  // CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.neg"([[RHSIM]]) : (tensor<4x2xf32>) -> tensor<4x2xf32>
+  // CHECK: [[RHSCONJ:%.+]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) : (tensor<4x2xf32>, tensor<4x2xf32>) -> tensor<4x2xcomplex<f32>>
+  // CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"([[LHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x5xcomplex<f32>>) -> tensor<2x5xcomplex<f32>>
+  // CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"([[RHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x2xcomplex<f32>>) -> tensor<4x2xcomplex<f32>>
+  // CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
+  // CHECK-SAME:   lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+  // CHECK-SAME:   lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
+  // CHECK-SAME:   rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
+  // CHECK-SAME:   rhs_contracting_dimensions = dense<1> : tensor<1xi64>
+  // CHECK-SAME: }} : (tensor<2x5xcomplex<f32>>, tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
+  // CHECK: return [[BDST]] : tensor<5x4xcomplex<f32>>
+  %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xcomplex<f32>>, tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
+  return %0 : tensor<5x4xcomplex<f32>>
+}
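
The complex test above leans on the identity conj(z) = complex(real(z),
-imag(z)), which is exactly the real/imag/neg/complex chain the CHECK lines
expect once tf.Conj is legalized. A scalar sketch of the same identity
(hypothetical helper, for illustration only):

    #include <complex>

    // Rebuild the conjugate from its parts, mirroring the expected
    // xla_hlo.real / xla_hlo.imag / xla_hlo.neg / xla_hlo.complex chain.
    std::complex<float> ConjViaParts(std::complex<float> z) {
      return std::complex<float>(z.real(), -z.imag());
    }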

@@ -1629,17 +1629,17 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
   LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
                                 PatternRewriter &rewriter) const override {
-    // TODO(silvasean): Handle adj_x/adj_y
-    // Should be able to just set the contracting_dimensions attribute
-    // appropriately.
-    // For complex types, need to do a complex conjugation.
-    if (op.adj_x() || op.adj_y()) return failure();
-
     Value lhs = op.x();
     Value rhs = op.y();
     auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
     auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
     if (!lhs_type || !rhs_type) return failure();
+    if (lhs_type.getElementType().isa<ComplexType>() && op.adj_x()) {
+      lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
+    }
+    if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
+      rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
+    }
     // TODO(silvasean): Support dynamic shapes.
     if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) {
       return failure();
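
Only the conjugation is materialized here; the transpose half of the adjoint
is folded into the dot_general dimension numbers in the next hunk. A sketch of
that selection as free-standing helpers (the commit inlines the ternaries;
these helper names are illustrative assumptions):

    #include <cstdint>

    // With adj_x, the lhs contracts over its second-to-last dimension,
    // i.e. over what would be the last dimension of its transpose.
    int64_t LhsContractingDim(int64_t rank, bool adj_x) {
      return adj_x ? rank - 2 : rank - 1;
    }

    // Symmetrically, adj_y moves the rhs contraction from the
    // second-to-last dimension to the last one.
    int64_t RhsContractingDim(int64_t rank, bool adj_y) {
      return adj_y ? rank - 1 : rank - 2;
    }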
@@ -1654,10 +1654,10 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
     int64_t rank = lhs_type.getRank();
     auto batch_dimensions = GetI64ElementsAttr(
         llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2)), &rewriter);
-    auto lhs_contracting_dimensions =
-        GetI64ElementsAttr(llvm::makeArrayRef({rank - 1}), &rewriter);
-    auto rhs_contracting_dimensions =
-        GetI64ElementsAttr(llvm::makeArrayRef({rank - 2}), &rewriter);
+    auto lhs_contracting_dimensions = GetI64ElementsAttr(
+        llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}), &rewriter);
+    auto rhs_contracting_dimensions = GetI64ElementsAttr(
+        llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}), &rewriter);
     auto dimension_numbers = DotDimensionNumbers::get(
         /*lhs_batching_dimensions=*/batch_dimensions,
         /*rhs_batching_dimensions=*/batch_dimensions,
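
Putting it together for a rank-3 example: x: [B, 2, 5] with adj_x and
y: [B, 4, 2] with adj_y batch over dimension 0 and contract dimension 1 of x
against dimension 2 of y, producing [B, 5, 4]. A standalone check of the
dimension arithmetic (illustrative only, not part of the commit):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      const int64_t rank = 3;
      const bool adj_x = true, adj_y = true;
      // Batch dimensions: all but the last two.
      std::vector<int64_t> batch_dims;
      for (int64_t d = 0; d < rank - 2; ++d) batch_dims.push_back(d);
      // Contracting dimensions, as computed in the patch above.
      const int64_t lhs_contracting = adj_x ? rank - 2 : rank - 1;  // 1
      const int64_t rhs_contracting = adj_y ? rank - 1 : rank - 2;  // 2
      std::cout << "batch_dims={0}, lhs_contracting=" << lhs_contracting
                << ", rhs_contracting=" << rhs_contracting << "\n";
    }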