TF2XLA: BatchMatMulV2: add adj_x/adj_y support
PiperOrigin-RevId: 302565809 Change-Id: Ib325e819e7ce913bc59deed6aedc5a29e0a28344
This commit is contained in:
parent
5d93f28897
commit
60d6ea479e
@ -3755,3 +3755,40 @@ func @batchmatmulv2_dynamic(%arg0: tensor<?x4x2xf32>, %arg1: tensor<?x2x4xf32>)
|
|||||||
return %0 : tensor<?x4x4xf32>
|
return %0 : tensor<?x4x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @batchmatmulv2_adj_real
|
||||||
|
func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
|
||||||
|
// CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xf32>) -> tensor<5x2xf32>
|
||||||
|
// CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
|
||||||
|
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||||
|
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
|
||||||
|
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||||
|
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||||
|
// CHECK-SAME: }} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
|
||||||
|
// CHECK: return [[BDST]] : tensor<5x4xf32>
|
||||||
|
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
|
||||||
|
return %0 : tensor<5x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @batchmatmulv2_adj_complex
|
||||||
|
func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
|
||||||
|
// CHECK: [[LHSRE:%.+]] = "xla_hlo.real"(%arg0) : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xf32>
|
||||||
|
// CHECK: [[LHSIM:%.+]] = "xla_hlo.imag"(%arg0) : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xf32>
|
||||||
|
// CHECK: [[LHSIMNEG:%.+]] = "xla_hlo.neg"([[LHSIM]]) : (tensor<5x2xf32>) -> tensor<5x2xf32>
|
||||||
|
// CHECK: [[LHSCONJ:%.+]] = "xla_hlo.complex"([[LHSRE]], [[LHSIMNEG]]) : (tensor<5x2xf32>, tensor<5x2xf32>) -> tensor<5x2xcomplex<f32>>
|
||||||
|
// CHECK: [[RHSRE:%.+]] = "xla_hlo.real"(%arg1) : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: [[RHSIM:%.+]] = "xla_hlo.imag"(%arg1) : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: [[RHSIMNEG:%.+]] = "xla_hlo.neg"([[RHSIM]]) : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||||
|
// CHECK: [[RHSCONJ:%.+]] = "xla_hlo.complex"([[RHSRE]], [[RHSIMNEG]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xcomplex<f32>>
|
||||||
|
// CHECK: [[BLHS:%.+]] = "xla_hlo.broadcast_in_dim"([[LHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<5x2xcomplex<f32>>) -> tensor<5x2xcomplex<f32>>
|
||||||
|
// CHECK: [[BRHS:%.+]] = "xla_hlo.broadcast_in_dim"([[RHSCONJ]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x4xcomplex<f32>>) -> tensor<2x4xcomplex<f32>>
|
||||||
|
// CHECK: [[BDST:%.+]] = "xla_hlo.dot_general"([[BLHS]], [[BRHS]]) {dot_dimension_numbers = {
|
||||||
|
// CHECK-SAME: lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||||
|
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
|
||||||
|
// CHECK-SAME: rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
|
||||||
|
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||||
|
// CHECK-SAME: }} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
|
||||||
|
// CHECK: return [[BDST]] : tensor<5x4xcomplex<f32>>
|
||||||
|
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
|
||||||
|
return %0 : tensor<5x4xcomplex<f32>>
|
||||||
|
}
|
||||||
|
@ -1629,17 +1629,17 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
|
|||||||
|
|
||||||
LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
|
LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// TODO(silvasean): Handle adj_x/adj_y
|
|
||||||
// Should be able to just set the contracting_dimensions attribute
|
|
||||||
// appropriately.
|
|
||||||
// For complex types, need to do a complex conjugation.
|
|
||||||
if (op.adj_x() || op.adj_y()) return failure();
|
|
||||||
|
|
||||||
Value lhs = op.x();
|
Value lhs = op.x();
|
||||||
Value rhs = op.y();
|
Value rhs = op.y();
|
||||||
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
|
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||||
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
|
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!lhs_type || !rhs_type) return failure();
|
if (!lhs_type || !rhs_type) return failure();
|
||||||
|
if (lhs_type.getElementType().isa<ComplexType>() && op.adj_x()) {
|
||||||
|
lhs = rewriter.create<TF::ConjOp>(op.getLoc(), lhs_type, lhs);
|
||||||
|
}
|
||||||
|
if (rhs_type.getElementType().isa<ComplexType>() && op.adj_y()) {
|
||||||
|
rhs = rewriter.create<TF::ConjOp>(op.getLoc(), rhs_type, rhs);
|
||||||
|
}
|
||||||
// TODO(silvasean): Support dynamic shapes.
|
// TODO(silvasean): Support dynamic shapes.
|
||||||
if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) {
|
if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) {
|
||||||
return failure();
|
return failure();
|
||||||
@ -1654,10 +1654,10 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern<TF::BatchMatMulV2Op> {
|
|||||||
int64_t rank = lhs_type.getRank();
|
int64_t rank = lhs_type.getRank();
|
||||||
auto batch_dimensions = GetI64ElementsAttr(
|
auto batch_dimensions = GetI64ElementsAttr(
|
||||||
llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2)), &rewriter);
|
llvm::to_vector<4>(llvm::seq<int64_t>(0, rank - 2)), &rewriter);
|
||||||
auto lhs_contracting_dimensions =
|
auto lhs_contracting_dimensions = GetI64ElementsAttr(
|
||||||
GetI64ElementsAttr(llvm::makeArrayRef({rank - 1}), &rewriter);
|
llvm::makeArrayRef({op.adj_x() ? rank - 2 : rank - 1}), &rewriter);
|
||||||
auto rhs_contracting_dimensions =
|
auto rhs_contracting_dimensions = GetI64ElementsAttr(
|
||||||
GetI64ElementsAttr(llvm::makeArrayRef({rank - 2}), &rewriter);
|
llvm::makeArrayRef({op.adj_y() ? rank - 1 : rank - 2}), &rewriter);
|
||||||
auto dimension_numbers = DotDimensionNumbers::get(
|
auto dimension_numbers = DotDimensionNumbers::get(
|
||||||
/*lhs_batching_dimensions=*/batch_dimensions,
|
/*lhs_batching_dimensions=*/batch_dimensions,
|
||||||
/*rhs_batching_dimensions=*/batch_dimensions,
|
/*rhs_batching_dimensions=*/batch_dimensions,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user