[TF/MLIR] Supports BatchMatMulV2 in fold-broadcast pass.
PiperOrigin-RevId: 357261459
Change-Id: If98fe99d9c864cde1d4deb935087dbfef7924257
commit 9421ecff2a (parent ab9516dd9a)
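In short: the fold-broadcast pass previously folded tf.BroadcastTo only into ops carrying the ResultsBroadcastableShape trait (plus tf.Equal/tf.NotEqual); this change extends it to tf.BatchMatMulV2, which implicitly broadcasts its batch dimensions. A sketch of the rewrite, paraphrased from the new tests below (IR abbreviated inside C++ comments; %lhs and %rhs are my names):

    // Before: the rhs is materialized at the full batch shape first.
    //   %0 = "tf.BroadcastTo"(%rhs, ...) : (tensor<17x24xf32>, tensor<3xi64>) -> tensor<17x17x24xf32>
    //   %1 = "tf.BatchMatMulV2"(%lhs, %0)
    // After: BatchMatMulV2 broadcasts batch dims itself, so the BroadcastTo is dropped.
    //   %1 = "tf.BatchMatMulV2"(%lhs, %rhs)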
@@ -72,3 +72,33 @@ func @broadcast_both_operand(%arg0: tensor<7xf32>, %arg1: tensor<5x1xf32>) -> tensor<5x7xf32> {
   // CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
   // CHECK: %[[V0]] : tensor<5x7xf32>
 }
+
+// CHECK-LABEL: @broadcast_batch_matmul_v2_rhs
+func @broadcast_batch_matmul_v2_rhs(%arg0: tensor<17x17x17xf32>, %arg1: tensor<17x24xf32>) -> tensor<17x17x24xf32> {
+  %cst = constant dense<[17, 17, 24]> : tensor<3xi64>
+  %0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<17x24xf32>, tensor<3xi64>) -> tensor<17x17x24xf32>
+  %1 = "tf.BatchMatMulV2"(%arg0, %0) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
+  return %1 : tensor<17x17x24xf32>
+  // CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x24xf32>) -> tensor<17x17x24xf32>
+  // CHECK: %[[V0]] : tensor<17x17x24xf32>
+}
+
+// CHECK-LABEL: @broadcast_batch_matmul_v2_lhs
+func @broadcast_batch_matmul_v2_lhs(%arg0: tensor<17x17xf32>, %arg1: tensor<17x17x24xf32>) -> tensor<17x17x24xf32> {
+  %cst = constant dense<[17, 17, 17]> : tensor<3xi64>
+  %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<17x17xf32>, tensor<3xi64>) -> tensor<17x17x17xf32>
+  %1 = "tf.BatchMatMulV2"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
+  return %1 : tensor<17x17x24xf32>
+  // CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
+  // CHECK: %[[V0]] : tensor<17x17x24xf32>
+}
+
+// CHECK-LABEL: @broadcast_batch_matmul_v2_failed
+func @broadcast_batch_matmul_v2_failed(%arg0: tensor<17x17x1xf32>, %arg1: tensor<17x17x24xf32>) -> tensor<17x17x24xf32> {
+  %cst = constant dense<[17, 17, 17]> : tensor<3xi64>
+  %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<17x17x1xf32>, tensor<3xi64>) -> tensor<17x17x17xf32>
+  %1 = "tf.BatchMatMulV2"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
+  return %1 : tensor<17x17x24xf32>
+  // CHECK: %[[V0:.*]] = "tf.BroadcastTo"
+  // CHECK: "tf.BatchMatMulV2"(%[[V0]], %arg1)
+}
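A note on the third test: the fold is rejected there because the pre-broadcast lhs has contraction dimension 1, which tf.BroadcastTo expands to 17; dropping the broadcast would change the product, and the pass's shape helper catches this via its contraction check. Below is a minimal standalone sketch of that check (plain C++, no MLIR dependencies; the helper names BroadcastOuterDims and BroadcastableMatmulShape are mine, and adj_x = adj_y = false is assumed):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    using Shape = std::vector<int64_t>;

    // Numpy-style broadcasting over the batch (outer) dimensions only.
    bool BroadcastOuterDims(const Shape& x, const Shape& y, Shape& out) {
      auto xi = x.rbegin(), yi = y.rbegin();
      while (xi != x.rend() || yi != y.rend()) {
        const int64_t a = xi != x.rend() ? *xi++ : 1;
        const int64_t b = yi != y.rend() ? *yi++ : 1;
        if (a != b && a != 1 && b != 1) return false;
        out.insert(out.begin(), std::max(a, b));
      }
      return true;
    }

    // Mirrors the pass's get_broadcasted_shape lambda: batch dims broadcast,
    // the two innermost dims follow matrix-multiplication rules.
    bool BroadcastableMatmulShape(const Shape& x, const Shape& y, Shape& out) {
      if (x.size() < 2 || y.size() < 2) return false;
      const Shape bx(x.begin(), x.end() - 2), by(y.begin(), y.end() - 2);
      if (!BroadcastOuterDims(bx, by, out)) return false;
      if (x.back() != y[y.size() - 2]) return false;  // contraction dims must match
      out.push_back(x[x.size() - 2]);  // result rows
      out.push_back(y.back());         // result cols
      return true;
    }

    int main() {
      Shape s;
      // rhs test case: the unbroadcast [17,24] still contracts cleanly -> foldable.
      assert(BroadcastableMatmulShape({17, 17, 17}, {17, 24}, s) &&
             s == Shape({17, 17, 24}));
      s.clear();
      // failed test case: unbroadcast lhs [17,17,1] has contraction dim 1 != 17,
      // so dropping the broadcast would change the result -> not foldable.
      assert(!BroadcastableMatmulShape({17, 17, 1}, {17, 17, 24}, s));
    }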
@@ -44,7 +44,14 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern {
   template <typename Op>
   LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
 
-  LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter) const;
+  LogicalResult RewriteOp(
+      Operation* op, PatternRewriter& rewriter,
+      const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
+                               SmallVectorImpl<int64_t>&)>&
+          get_broadcasted_shape) const;
+
+  LogicalResult RewriteBatchMatMulV2Op(Operation* op,
+                                       PatternRewriter& rewriter) const;
 };
 
 class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
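The declaration change above turns RewriteOp's shape computation into a parameter: callers now pass a callable that decides whether two operand shapes broadcast to a result shape. A sketch of how both rules fit that seam (the alias name BroadcastShapeFn is mine; OpTrait::util::getBroadcastedShape is MLIR's generic numpy-style helper from mlir/Dialect/Traits.h):

    #include <functional>

    #include "mlir/Dialect/Traits.h"  // OpTrait::util::getBroadcastedShape

    using BroadcastShapeFn =
        std::function<bool(llvm::ArrayRef<int64_t>, llvm::ArrayRef<int64_t>,
                           llvm::SmallVectorImpl<int64_t>&)>;

    // Elementwise ops (ResultsBroadcastableShape) broadcast across every
    // dimension, so they bind the generic helper:
    BroadcastShapeFn elementwise_rule = &mlir::OpTrait::util::getBroadcastedShape;

    // tf.BatchMatMulV2 broadcasts batch dimensions only; it binds the custom
    // lambda defined in RewriteBatchMatMulV2Op below instead.

Keeping the broadcast-compatibility decision behind one callable means RewriteOp's operand and result plumbing is shared by all three call sites.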
@@ -55,26 +62,78 @@ class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
 LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
-    return RewriteOp(op, rewriter);
+    return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
 
   // tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
   // incompatible_shape_error is `true` (which is also checked by the verifier).
   if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
   if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
+  if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
 
   return failure();
 }
 
+LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
+    Operation* op, PatternRewriter& rewriter) const {
+  auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
+  if (!matmul_op) return failure();
+
+  // Gets the broadcasted output shape for tf.BatchMatMulV2Op. `shape_x` is the
+  // shape of the op's first/left-hand-side operand and `shape_y` is the shape
+  // of the op's second/right-hand-side operand.
+  const auto get_broadcasted_shape =
+      [&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
+          SmallVectorImpl<int64_t>& result_shape) {
+        if (shape_x.size() < 2 || shape_y.size() < 2) {
+          return false;
+        }
+
+        // Checks that the outer dimensions (i.e., all but the two innermost)
+        // are broadcast compatible; if so, writes their broadcasted shape to
+        // `result_shape`.
+        if (!OpTrait::util::getBroadcastedShape(
+                shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
+          return false;
+        }
+
+        const int x_row =
+            matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
+        const int x_col =
+            !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
+
+        const int y_row =
+            matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
+        const int y_col =
+            !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
+
+        // Checks that the matrix multiply performs a valid contraction.
+        if (x_col != y_row) {
+          result_shape.clear();
+          return false;
+        }
+
+        result_shape.push_back(x_row);
+        result_shape.push_back(y_col);
+        return true;
+      };
+
+  return RewriteOp(op, rewriter, get_broadcasted_shape);
+}
+
 template <typename Op>
 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
     Operation* op, PatternRewriter& rewriter) const {
   auto eq_op = llvm::dyn_cast_or_null<Op>(op);
-  if (eq_op && eq_op.incompatible_shape_error()) return RewriteOp(op, rewriter);
+  if (eq_op && eq_op.incompatible_shape_error())
+    return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
   return failure();
 }
 
 LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
-    Operation* op, PatternRewriter& rewriter) const {
+    Operation* op, PatternRewriter& rewriter,
+    const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
+                             SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
+    const {
   if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
     return failure();
 
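For concreteness, a trace of the lambda above on the first new test (values taken from the test file; adj_x = adj_y = false):

    // shape_x = [17, 17, 17]:  x_row = *(rbegin()+1) = 17, x_col = back() = 17
    // shape_y = [17, 24]:      y_row = *(rbegin()+1) = 17, y_col = back() = 24
    // outer dims [17] vs []    -> broadcast to [17]
    // x_col == y_row (17)      -> result_shape = [17, 17, 24]
    //
    // With adj_x = true the roles flip: x_row would read back() and x_col
    // *(rbegin()+1), matching a transposed left-hand side.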
@@ -102,12 +161,16 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
             .dyn_cast_or_null<RankedTensorType>();
     if (!argument_type || !argument_type.hasStaticShape()) continue;
 
+    // Get the unbroadcasted shapes in the operand order.
+    std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
+    operand_shapes[i] = broadcast_arg_type.getShape();
+    operand_shapes[1 - i] = argument_type.getShape();
+
     // Check that the input of the broadcast and the other operand are
     // broadcast compatible.
     llvm::SmallVector<int64_t, 4> broadcasted_shape;
-    if (!OpTrait::util::getBroadcastedShape(broadcast_arg_type.getShape(),
-                                            argument_type.getShape(),
-                                            broadcasted_shape))
+    if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
+                               broadcasted_shape))
       continue;
 
     // Check that an implicit broadcast between the operand of the broadcast and
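Why the new operand_shapes array is needed: the loop tries each operand in turn as the candidate tf.BroadcastTo producer, and the old generic helper was symmetric in its arguments, so order didn't matter. The matmul rule is not symmetric, so the shapes must be restored to operand order before the callback runs. An illustration with the shapes from the tests (hypothetical call results, derived from the lambda above):

    // With the matmul-aware callback, argument order changes the answer:
    //   get_broadcasted_shape([17, 17, 17], [17, 24])  -> true, result [17, 17, 24]
    //   get_broadcasted_shape([17, 24], [17, 17, 17])  -> false (contraction 24 != 17)
    // The symmetric elementwise helper would treat both orders identically.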