TF2XLA: BatchMatMulV2: add adj_x/adj_y support
PiperOrigin-RevId: 302565809 Change-Id: Ib325e819e7ce913bc59deed6aedc5a29e0a28344
This commit is contained in:
parent
5d93f28897
commit
60d6ea479e
@ -3755,3 +3755,40 @@ func @batchmatmulv2_dynamic(%arg0: tensor<?x4x2xf32>, %arg1: tensor<?x2x4xf32>)
|
||||
return %0 : tensor<?x4x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @batchmatmulv2_adj_real
|
||||
func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
|
||||
// CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xf32>) -> tensor<5x2xf32>
|
||||
// CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: }} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
|
||||
// CHECK: return [[BDST]] : tensor<5x4xf32>
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
|
||||
return %0 : tensor<5x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @batchmatmulv2_adj_complex
|
||||
func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
|
||||
// CHECK: [[LHSRE:%.+]] = "xla_hlo.real"(%arg0) : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xf32>
|
||||
// CHECK: [[LHSIM:%.+]] = "xla_hlo.imag"(%arg0) : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xf32>
|
||||
// CHECK: [[LHSIMNEG:%.+]] = "xla_hlo.neg"([[LHSIM]]) : (tensor<5x2xf32>) -> tensor<5x2xf32>
|
||||
// CHECK: [[LHSCONJ:%.+]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) : (tensor<5x2xf32>, tensor<5x2xf32>) -> tensor<5x2xcomplex<f32>>
|
||||
// CHECK: [[RHSRE:%.+]] = "xla_hlo.real"(%arg1) : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xf32>
|
||||
// CHECK: [[RHSIM:%.+]] = "xla_hlo.imag"(%arg1) : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xf32>
|
||||
// CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.neg"([[RHSIM]]) : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: [[RHSCONJ:%.+]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xcomplex<f32>>
|
||||
// CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"([[LHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xcomplex<f32>>
|
||||
// CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"([[RHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xcomplex<f32>>
|
||||
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: }} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
|
||||
// CHECK: return [[BDST]] : tensor<5x4xcomplex<f32>>
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
|
||||
return %0 : tensor<5x4xcomplex<f32>>
|
||||
}
|
||||
|
@ -1629,17 +1629,17 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
|
||||
|
||||
LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// TODO(silvasean): Handle adj_x/adj_y
|
||||
// Should be able to just set the contracting_dimensions attribute
|
||||
// appropriately.
|
||||
// For complex types, need to do a complex conjugation.
|
||||
if (op.adj_x() || op.adj_y()) return failure();
|
||||
|
||||
Value lhs = op.x();
|
||||
Value rhs = op.y();
|
||||
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
if (!lhs_type || !rhs_type) return failure();
|
||||
if (lhs_type.getElementType().isa<ComplexType>() && op.adj_x()) {
|
||||
lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
|
||||
}
|
||||
if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
|
||||
rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
|
||||
}
|
||||
// TODO(silvasean): Support dynamic shapes.
|
||||
if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) {
|
||||
return failure();
|
||||
@ -1654,10 +1654,10 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
|
||||
int64_t rank = lhs_type.getRank();
|
||||
auto batch_dimensions = GetI64ElementsAttr(
|
||||
llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2)), &rewriter);
|
||||
auto lhs_contracting_dimensions =
|
||||
GetI64ElementsAttr(llvm::makeArrayRef({rank - 1}), &rewriter);
|
||||
auto rhs_contracting_dimensions =
|
||||
GetI64ElementsAttr(llvm::makeArrayRef({rank - 2}), &rewriter);
|
||||
auto lhs_contracting_dimensions = GetI64ElementsAttr(
|
||||
llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}), &rewriter);
|
||||
auto rhs_contracting_dimensions = GetI64ElementsAttr(
|
||||
llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}), &rewriter);
|
||||
auto dimension_numbers = DotDimensionNumbers::get(
|
||||
/*lhs_batching_dimensions=*/batch_dimensions,
|
||||
/*rhs_batching_dimensions=*/batch_dimensions,
|
||||
|
Loading…
x
Reference in New Issue
Block a user