[TF/MLIR] Supports BatchMatMulV2 in fold-broadcast pass.
PiperOrigin-RevId: 357261459 Change-Id: If98fe99d9c864cde1d4deb935087dbfef7924257
This commit is contained in:
parent
ab9516dd9a
commit
9421ecff2a
@ -72,3 +72,33 @@ func @broadcast_both_operand(%arg0: tensor<7xf32>, %arg1: tensor<5x1xf32>) -> te
|
|||||||
// CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
|
// CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
|
||||||
// CHECK: %[[V0]] : tensor<5x7xf32>
|
// CHECK: %[[V0]] : tensor<5x7xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @broadcast_batch_matmul_v2_rhs
|
||||||
|
func @broadcast_batch_matmul_v2_rhs(%arg0: tensor<17x17x17xf32>, %arg1: tensor<17x24xf32>) -> tensor<17x17x24xf32> {
|
||||||
|
%cst = constant dense<[17, 17, 24]> : tensor<3xi64>
|
||||||
|
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<17x24xf32>, tensor<3xi64>) -> tensor<17x17x24xf32>
|
||||||
|
%1 = "tf.BatchMatMulV2"(%arg0, %0) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
|
||||||
|
return %1 : tensor<17x17x24xf32>
|
||||||
|
// CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x24xf32>) -> tensor<17x17x24xf32>
|
||||||
|
// CHECK: %[[V0]] : tensor<17x17x24xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @broadcast_batch_matmul_v2_lhs
|
||||||
|
func @broadcast_batch_matmul_v2_lhs(%arg0: tensor<17x17xf32>, %arg1: tensor<17x17x24xf32>) -> tensor<17x17x24xf32> {
|
||||||
|
%cst = constant dense<[17, 17, 17]> : tensor<3xi64>
|
||||||
|
%0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<17x17xf32>, tensor<3xi64>) -> tensor<17x17x17xf32>
|
||||||
|
%1 = "tf.BatchMatMulV2"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
|
||||||
|
return %1 : tensor<17x17x24xf32>
|
||||||
|
// CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
|
||||||
|
// CHECK: %[[V0]] : tensor<17x17x24xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @broadcast_batch_matmul_v2_failed
|
||||||
|
func @broadcast_batch_matmul_v2_failed(%arg0: tensor<17x17x1xf32>, %arg1: tensor<17x17x24xf32>) -> tensor<17x17x24xf32> {
|
||||||
|
%cst = constant dense<[17, 17, 17]> : tensor<3xi64>
|
||||||
|
%0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<17x17x1xf32>, tensor<3xi64>) -> tensor<17x17x17xf32>
|
||||||
|
%1 = "tf.BatchMatMulV2"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
|
||||||
|
return %1 : tensor<17x17x24xf32>
|
||||||
|
// CHECK: %[[V0:.*]] = "tf.BroadcastTo"
|
||||||
|
// CHECK: "tf.BatchMatMulV2"(%[[V0]], %arg1)
|
||||||
|
}
|
||||||
|
@ -44,7 +44,14 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern {
|
|||||||
template <typename Op>
|
template <typename Op>
|
||||||
LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
|
LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
|
||||||
|
|
||||||
LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter) const;
|
LogicalResult RewriteOp(
|
||||||
|
Operation* op, PatternRewriter& rewriter,
|
||||||
|
const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
|
||||||
|
SmallVectorImpl<int64_t>&)>&
|
||||||
|
get_broadcasted_shape) const;
|
||||||
|
|
||||||
|
LogicalResult RewriteBatchMatMulV2Op(Operation* op,
|
||||||
|
PatternRewriter& rewriter) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
|
class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
|
||||||
@ -55,26 +62,78 @@ class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
|
|||||||
LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
|
LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
|
||||||
Operation* op, PatternRewriter& rewriter) const {
|
Operation* op, PatternRewriter& rewriter) const {
|
||||||
if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
|
if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
|
||||||
return RewriteOp(op, rewriter);
|
return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
|
||||||
|
|
||||||
// tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
|
// tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
|
||||||
// incompatible_shape_error is `true` (what is also checked by the verifier).
|
// incompatible_shape_error is `true` (what is also checked by the verifier).
|
||||||
if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
|
if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
|
||||||
if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
|
if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
|
||||||
|
if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
|
||||||
|
Operation* op, PatternRewriter& rewriter) const {
|
||||||
|
auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
|
||||||
|
if (!matmul_op) return failure();
|
||||||
|
|
||||||
|
// Gets the broadcasted output shape for tf.BatchMatMulV2Op. `shape_x` is the
|
||||||
|
// shape of op's first/left-hand-side operand and `shape_y` is the shape of
|
||||||
|
// op's second/right-hand-side operand.
|
||||||
|
const auto get_broadcasted_shape =
|
||||||
|
[&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
|
||||||
|
SmallVectorImpl<int64_t>& result_shape) {
|
||||||
|
if (shape_x.size() < 2 || shape_y.size() < 2) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks outer dimensions (i.e., the dimensions higher than 2D) are
|
||||||
|
// broadcastable. If true, then get the broadcasted shape for outer
|
||||||
|
// dimension.
|
||||||
|
if (!OpTrait::util::getBroadcastedShape(
|
||||||
|
shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int x_row =
|
||||||
|
matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
|
||||||
|
const int x_col =
|
||||||
|
!matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
|
||||||
|
|
||||||
|
const int y_row =
|
||||||
|
matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
|
||||||
|
const int y_col =
|
||||||
|
!matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
|
||||||
|
|
||||||
|
// Checks that matrix multiply can perform a valid contraction.
|
||||||
|
if (x_col != y_row) {
|
||||||
|
result_shape.clear();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
result_shape.push_back(x_row);
|
||||||
|
result_shape.push_back(y_col);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
return RewriteOp(op, rewriter, get_broadcasted_shape);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Op>
|
template <typename Op>
|
||||||
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
|
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
|
||||||
Operation* op, PatternRewriter& rewriter) const {
|
Operation* op, PatternRewriter& rewriter) const {
|
||||||
auto eq_op = llvm::dyn_cast_or_null<Op>(op);
|
auto eq_op = llvm::dyn_cast_or_null<Op>(op);
|
||||||
if (eq_op && eq_op.incompatible_shape_error()) return RewriteOp(op, rewriter);
|
if (eq_op && eq_op.incompatible_shape_error())
|
||||||
|
return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
|
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
|
||||||
Operation* op, PatternRewriter& rewriter) const {
|
Operation* op, PatternRewriter& rewriter,
|
||||||
|
const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
|
||||||
|
SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
|
||||||
|
const {
|
||||||
if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
|
if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@ -102,12 +161,16 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
|
|||||||
.dyn_cast_or_null<RankedTensorType>();
|
.dyn_cast_or_null<RankedTensorType>();
|
||||||
if (!argument_type || !argument_type.hasStaticShape()) continue;
|
if (!argument_type || !argument_type.hasStaticShape()) continue;
|
||||||
|
|
||||||
|
// Get the unbroadcasted shapes in the operand order.
|
||||||
|
std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
|
||||||
|
operand_shapes[i] = broadcast_arg_type.getShape();
|
||||||
|
operand_shapes[1 - i] = argument_type.getShape();
|
||||||
|
|
||||||
// Check that the input of the broadcast and the other operand is broadcast
|
// Check that the input of the broadcast and the other operand is broadcast
|
||||||
// compatible.
|
// compatible.
|
||||||
llvm::SmallVector<int64_t, 4> broadcasted_shape;
|
llvm::SmallVector<int64_t, 4> broadcasted_shape;
|
||||||
if (!OpTrait::util::getBroadcastedShape(broadcast_arg_type.getShape(),
|
if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
|
||||||
argument_type.getShape(),
|
broadcasted_shape))
|
||||||
broadcasted_shape))
|
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Check that an implicit broadcast between the operand of the broadcast and
|
// Check that an implicit broadcast between the operand of the broadcast and
|
||||||
|
Loading…
Reference in New Issue
Block a user