[TF/MLIR] Supports BatchMatMulV2 in fold-broadcast pass.

PiperOrigin-RevId: 357261459
Change-Id: If98fe99d9c864cde1d4deb935087dbfef7924257
This commit is contained in:
A. Unique TensorFlower 2021-02-12 13:50:52 -08:00 committed by TensorFlower Gardener
parent ab9516dd9a
commit 9421ecff2a
2 changed files with 100 additions and 7 deletions

View File

@ -72,3 +72,33 @@ func @broadcast_both_operand(%arg0: tensor<7xf32>, %arg1: tensor<5x1xf32>) -> te
// CHECK: %[[V0:.*]] = "tf.Add"(%arg0, %arg1) : (tensor<7xf32>, tensor<5x1xf32>) -> tensor<5x7xf32>
// CHECK: %[[V0]] : tensor<5x7xf32>
}
// Tests that an explicit tf.BroadcastTo feeding the right-hand side of
// tf.BatchMatMulV2 is folded away: the CHECK lines expect the op to consume
// %arg1 (tensor<17x24xf32>) directly, with no tf.BroadcastTo in the output,
// since the op's own batch-dimension broadcasting yields the same result.
// CHECK-LABEL: @broadcast_batch_matmul_v2_rhs
func @broadcast_batch_matmul_v2_rhs(%arg0: tensor<17x17x17xf32>, %arg1: tensor<17x24xf32>) -> tensor<17x17x24xf32> {
// Materializes the RHS broadcast (17x24 -> 17x17x24) that the pass removes.
%cst = constant dense<[17, 17, 24]> : tensor<3xi64>
%0 = "tf.BroadcastTo"(%arg1, %cst) : (tensor<17x24xf32>, tensor<3xi64>) -> tensor<17x17x24xf32>
%1 = "tf.BatchMatMulV2"(%arg0, %0) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
return %1 : tensor<17x17x24xf32>
// After folding, BatchMatMulV2 takes the unbroadcast %arg1 and still produces
// the original result type.
// CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x24xf32>) -> tensor<17x17x24xf32>
// CHECK: %[[V0]] : tensor<17x17x24xf32>
}
// Mirror of the RHS test: an explicit tf.BroadcastTo feeding the left-hand
// side of tf.BatchMatMulV2 is folded away, so the op consumes %arg0
// (tensor<17x17xf32>) directly.
// CHECK-LABEL: @broadcast_batch_matmul_v2_lhs
func @broadcast_batch_matmul_v2_lhs(%arg0: tensor<17x17xf32>, %arg1: tensor<17x17x24xf32>) -> tensor<17x17x24xf32> {
// Materializes the LHS broadcast (17x17 -> 17x17x17) that the pass removes.
%cst = constant dense<[17, 17, 17]> : tensor<3xi64>
%0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<17x17xf32>, tensor<3xi64>) -> tensor<17x17x17xf32>
%1 = "tf.BatchMatMulV2"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
return %1 : tensor<17x17x24xf32>
// CHECK: %[[V0:.*]] = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
// CHECK: %[[V0]] : tensor<17x17x24xf32>
}
// Negative test: here the broadcast changes the LHS contraction dimension
// (last dim 1 -> 17), so folding would alter the matmul's semantics. The
// CHECK lines expect the tf.BroadcastTo to be kept and BatchMatMulV2 to keep
// consuming its result.
// CHECK-LABEL: @broadcast_batch_matmul_v2_failed
func @broadcast_batch_matmul_v2_failed(%arg0: tensor<17x17x1xf32>, %arg1: tensor<17x17x24xf32>) -> tensor<17x17x24xf32> {
%cst = constant dense<[17, 17, 17]> : tensor<3xi64>
%0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<17x17x1xf32>, tensor<3xi64>) -> tensor<17x17x17xf32>
%1 = "tf.BatchMatMulV2"(%0, %arg1) {adj_x = false, adj_y = false} : (tensor<17x17x17xf32>, tensor<17x17x24xf32>) -> tensor<17x17x24xf32>
return %1 : tensor<17x17x24xf32>
// CHECK: %[[V0:.*]] = "tf.BroadcastTo"
// CHECK: "tf.BatchMatMulV2"(%[[V0]], %arg1)
}

View File

@ -44,7 +44,14 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern {
template <typename Op>
LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter) const;
LogicalResult RewriteOp(
Operation* op, PatternRewriter& rewriter,
const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
SmallVectorImpl<int64_t>&)>&
get_broadcasted_shape) const;
LogicalResult RewriteBatchMatMulV2Op(Operation* op,
PatternRewriter& rewriter) const;
};
class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
@ -55,26 +62,78 @@ class BroadcastFoldPass : public PassWrapper<BroadcastFoldPass, FunctionPass> {
// Pattern entry point: dispatches to the rewrite helper matching the op kind.
// Ops with the ResultsBroadcastableShape trait use the generic broadcast-shape
// computation; tf.Equal/tf.NotEqual and tf.BatchMatMulV2 get dedicated paths.
LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
// NOTE(review): this is a unified-diff dump — the next two lines are the
// removed (old) and added (new) form of the same call; only the second,
// which passes the broadcast-shape callback explicitly, exists after the
// change.
return RewriteOp(op, rewriter);
return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
// tf.Equal and tf.NotEqual ops only satisfy ResultsBroadcastableShape when
// incompatible_shape_error is `true` (what is also checked by the verifier).
if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
return failure();
}
// Folds tf.BroadcastTo ops feeding a tf.BatchMatMulV2 when the matmul's own
// implicit batch broadcasting reproduces the same result shape. Returns
// failure() if `op` is not a tf.BatchMatMulV2 or no operand can be folded.
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
    Operation* op, PatternRewriter& rewriter) const {
  auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
  if (!matmul_op) return failure();

  // Computes the broadcasted output shape for tf.BatchMatMulV2. `shape_x` is
  // the shape of the op's first/left-hand-side operand and `shape_y` is the
  // shape of the op's second/right-hand-side operand. Returns false (with
  // `result_shape` left empty) when either operand has rank < 2, the batch
  // dimensions are not broadcast-compatible, or the contraction dimensions
  // disagree.
  const auto get_broadcasted_shape =
      [&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
          SmallVectorImpl<int64_t>& result_shape) {
        // Both operands must carry at least a trailing 2-D matrix.
        if (shape_x.size() < 2 || shape_y.size() < 2) {
          return false;
        }
        // Checks the outer (batch) dimensions — everything above the trailing
        // two — are broadcastable; on success `result_shape` holds their
        // broadcasted shape.
        if (!OpTrait::util::getBroadcastedShape(
                shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
          // Keep the failure contract uniform: never hand back a partial
          // shape.
          result_shape.clear();
          return false;
        }
        // Matrix dimensions, honoring the adjoint attributes. Use int64_t to
        // match MLIR shape dimensions and avoid narrowing large dims to int.
        const int64_t x_row =
            matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
        const int64_t x_col =
            !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
        const int64_t y_row =
            matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
        const int64_t y_col =
            !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
        // Checks that matrix multiply can perform a valid contraction.
        if (x_col != y_row) {
          result_shape.clear();
          return false;
        }
        result_shape.push_back(x_row);
        result_shape.push_back(y_col);
        return true;
      };

  return RewriteOp(op, rewriter, get_broadcasted_shape);
}
// Rewrites tf.Equal / tf.NotEqual (selected via the Op template parameter).
// These ops only satisfy ResultsBroadcastableShape when
// incompatible_shape_error is true, so that is re-checked here before folding.
template <typename Op>
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
Operation* op, PatternRewriter& rewriter) const {
auto eq_op = llvm::dyn_cast_or_null<Op>(op);
// NOTE(review): unified-diff residue — the next line is the removed (old)
// single-line form; the two lines after it are the added (new) form passing
// the broadcast-shape callback. Only the new form exists after the change.
if (eq_op && eq_op.incompatible_shape_error()) return RewriteOp(op, rewriter);
if (eq_op && eq_op.incompatible_shape_error())
return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
return failure();
}
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
Operation* op, PatternRewriter& rewriter) const {
Operation* op, PatternRewriter& rewriter,
const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
const {
if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
return failure();
@ -102,12 +161,16 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
.dyn_cast_or_null<RankedTensorType>();
if (!argument_type || !argument_type.hasStaticShape()) continue;
// Get the unbroadcasted shapes in the operand order.
std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
operand_shapes[i] = broadcast_arg_type.getShape();
operand_shapes[1 - i] = argument_type.getShape();
// Check that the input of the broadcast and the other operand is broadcast
// compatible.
llvm::SmallVector<int64_t, 4> broadcasted_shape;
if (!OpTrait::util::getBroadcastedShape(broadcast_arg_type.getShape(),
argument_type.getShape(),
broadcasted_shape))
if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
broadcasted_shape))
continue;
// Check that an implicit broadcast between the operand of the broadcast and