Verify and Canonicalize static shaped BatchMatMul op to BatchMatMulV2 op
BatchMatMul op requires batch dimensions to match while BatchMatMulV2 op requires batch dimensions to be broadcast compatible. So, if the batch dimensions of BatchMatMul are matching, it can be canonicalized to V2 op. PiperOrigin-RevId: 346189258 Change-Id: I4bf055041cec42157c1d464a4d732597f44409c9
This commit is contained in:
parent
0b66713efa
commit
be7b8874d0
@ -870,6 +870,10 @@ It is computed as:
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
|
||||||
let hasCanonicalizer = 1;
|
let hasCanonicalizer = 1;
|
||||||
|
|
||||||
|
let verifier = [{
|
||||||
|
return Verify(*this);
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
|
def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
|
||||||
|
@ -157,19 +157,13 @@ void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// BatchMatMulOp
|
// BatchMatMulV2Op & BatchMatMulOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
void BatchMatMulOp::getCanonicalizationPatterns(
|
template <typename OpT,
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
typename std::enable_if<llvm::is_one_of<
|
||||||
results.insert<BatchMatMulToMatMul>(context);
|
OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type * = nullptr>
|
||||||
}
|
static LogicalResult Verify(OpT op) {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// BatchMatMulV2Op
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
static LogicalResult Verify(BatchMatMulV2Op op) {
|
|
||||||
if (!HasRankAtLeast(op.x(), 2)) {
|
if (!HasRankAtLeast(op.x(), 2)) {
|
||||||
return op.emitOpError("requires lhs operand to have rank at least two");
|
return op.emitOpError("requires lhs operand to have rank at least two");
|
||||||
}
|
}
|
||||||
@ -185,17 +179,34 @@ static LogicalResult Verify(BatchMatMulV2Op op) {
|
|||||||
ArrayRef<int64_t> x_shape = x_ty.getShape();
|
ArrayRef<int64_t> x_shape = x_ty.getShape();
|
||||||
ArrayRef<int64_t> y_shape = y_ty.getShape();
|
ArrayRef<int64_t> y_shape = y_ty.getShape();
|
||||||
|
|
||||||
// Check broadcast compatibility if both input shapes are known.
|
llvm::SmallVector<int64_t, 4> result_batch_shape;
|
||||||
|
llvm::ArrayRef<int64_t> x_batches = x_shape.drop_back(2);
|
||||||
|
llvm::ArrayRef<int64_t> y_batches = y_shape.drop_back(2);
|
||||||
|
|
||||||
|
// Check compatibility of batch dimensions if both input shapes are known.
|
||||||
|
// BatchMatMul should have exactly the same batch dimensions and
|
||||||
|
// BatchMatMulV2 should have broadcastable batch dimensions.
|
||||||
//
|
//
|
||||||
// The last two dimensions are non-batch dimensions that don't need to
|
// The last two dimensions are non-batch dimensions that don't need to
|
||||||
// participate in batch dimension compatibility check.
|
// participate in batch dimension compatibility check.
|
||||||
|
if (std::is_same<OpT, BatchMatMulOp>()) {
|
||||||
llvm::SmallVector<int64_t, 4> result_batch_shape;
|
for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) {
|
||||||
if (!OpTrait::util::getBroadcastedShape(
|
int64_t x_dim = std::get<0>(dim_pairs);
|
||||||
x_shape.drop_back(2), y_shape.drop_back(2), result_batch_shape))
|
int64_t y_dim = std::get<1>(dim_pairs);
|
||||||
return op.emitOpError()
|
if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) &&
|
||||||
<< "found incompatible broadcast batch dimensions for lhs shape "
|
x_dim != y_dim) {
|
||||||
<< x_ty << " and rhs shape " << y_ty;
|
return op.emitOpError()
|
||||||
|
<< "found mismatching batch dimensions for lhs shape " << x_ty
|
||||||
|
<< " and rhs shape " << y_ty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (!OpTrait::util::getBroadcastedShape(x_batches, y_batches,
|
||||||
|
result_batch_shape))
|
||||||
|
return op.emitOpError()
|
||||||
|
<< "found incompatible broadcast batch dimensions for lhs shape "
|
||||||
|
<< x_ty << " and rhs shape " << y_ty;
|
||||||
|
}
|
||||||
|
|
||||||
RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
|
RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
|
||||||
if (!output_ty) return success();
|
if (!output_ty) return success();
|
||||||
@ -245,6 +256,11 @@ static LogicalResult Verify(BatchMatMulV2Op op) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void BatchMatMulOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<BatchMatMulToV2>(context);
|
||||||
|
}
|
||||||
|
|
||||||
void BatchMatMulV2Op::getCanonicalizationPatterns(
|
void BatchMatMulV2Op::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
results.insert<BatchMatMulV2ToMatMul>(context);
|
results.insert<BatchMatMulV2ToMatMul>(context);
|
||||||
|
@ -16,6 +16,20 @@ func @tfAssertFalse(%arg0: tensor<1x1x6x2xf32>) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: testBatchMatMulToV2
|
||||||
|
func @testBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>) -> tensor<2x3x7xf32> {
|
||||||
|
// CHECK: tf.BatchMatMulV2
|
||||||
|
%0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32>
|
||||||
|
return %0: tensor<2x3x7xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: testDynamicBatchMatMulToV2
|
||||||
|
func @testDynamicBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<?x5x7xf32>) -> tensor<2x3x7xf32> {
|
||||||
|
// CHECK: tf.BatchMatMul
|
||||||
|
%0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3x5xf32>, tensor<?x5x7xf32>) -> tensor<2x3x7xf32>
|
||||||
|
return %0: tensor<2x3x7xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: testBatchMatMulToMatMul
|
// CHECK-LABEL: testBatchMatMulToMatMul
|
||||||
func @testBatchMatMulToMatMul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> tensor<2x2xf32> {
|
func @testBatchMatMulToMatMul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> tensor<2x2xf32> {
|
||||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32>
|
%0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32>
|
||||||
|
@ -3530,6 +3530,22 @@ func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Legal BatchMatMul op.
|
||||||
|
func @testBatchMatMul(%lhs: tensor<2x?x2x?x3x5xf32>, %rhs: tensor<2x2x?x?x5x7xf32>) {
|
||||||
|
%0 = "tf.BatchMatMul"(%lhs, %rhs) : (tensor<2x?x2x?x3x5xf32>, tensor<2x2x?x?x5x7xf32>) -> tensor<2x?x?x?x3x7xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Mismatching batch dimensions.
|
||||||
|
func @testBatchMatMul(%lhs: tensor<1x3x5xf32>, %rhs: tensor<2x5x7xf32>) {
|
||||||
|
// expected-error @+1 {{found mismatching batch dimensions for lhs shape 'tensor<1x3x5xf32>' and rhs shape 'tensor<2x5x7xf32>'}}
|
||||||
|
%0 = "tf.BatchMatMul"(%lhs, %rhs) : (tensor<1x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @testBatchMatMulV2(%lhs: tensor<f32>, %rhs: tensor<10x10xf32>) {
|
func @testBatchMatMulV2(%lhs: tensor<f32>, %rhs: tensor<10x10xf32>) {
|
||||||
// expected-error @+1 {{requires lhs operand to have rank at least two}}
|
// expected-error @+1 {{requires lhs operand to have rank at least two}}
|
||||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<f32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<f32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
|
@ -219,23 +219,3 @@ func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tenso
|
|||||||
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
|
||||||
// CHECK: return %[[v0]] : tensor<4x6xf32>
|
// CHECK: return %[[v0]] : tensor<4x6xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @batchMatMulVectorLhsInputMatchFailure(%arg0: tensor<10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> {
|
|
||||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
|
|
||||||
return %0 : tensor<10x20xf32>
|
|
||||||
|
|
||||||
// CHECK-LABEL: batchMatMulVectorLhs
|
|
||||||
// CHECK: %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func @batchMatMulVectorRhsInputMatchFailure(%arg0: tensor<10x20xf32>, %arg1: tensor<10xf32>) -> tensor<10x20xf32> {
|
|
||||||
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32>
|
|
||||||
return %0 : tensor<10x20xf32>
|
|
||||||
|
|
||||||
// CHECK-LABEL: batchMatMulVectorRhs
|
|
||||||
// CHECK: %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32>
|
|
||||||
}
|
|
||||||
|
@ -55,6 +55,16 @@ def AddV2OfNegRight : Pat<(TF_AddV2Op $arg0, (TF_NegOp $arg1)),
|
|||||||
// BatchMatMul op patterns.
|
// BatchMatMul op patterns.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Static shaped operands in a legal BatchMatMul op will have matching batch
|
||||||
|
// dimensions and can be upgraded to the BatchMatMulV2 op. Canonicalizing
|
||||||
|
// dynamically shaped operands is not correct as that will execute ops that
|
||||||
|
// have non matching batch dimensions but are broadcastable which should fail
|
||||||
|
// with V1.
|
||||||
|
def BatchMatMulToV2 :
|
||||||
|
Pat<(TF_BatchMatMulOp AnyStaticShapeTensor:$x, AnyStaticShapeTensor:$y,
|
||||||
|
$adj_x, $adj_y),
|
||||||
|
(TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y)>;
|
||||||
|
|
||||||
def BatchMatMulToMatMul : Pat<(TF_BatchMatMulOp $x, $y, $adj_x, $adj_y),
|
def BatchMatMulToMatMul : Pat<(TF_BatchMatMulOp $x, $y, $adj_x, $adj_y),
|
||||||
(TF_MatMulOp $x, $y, $adj_x, $adj_y),
|
(TF_MatMulOp $x, $y, $adj_x, $adj_y),
|
||||||
[(IsRank2Tensor $x), (IsRank2Tensor $y)]>;
|
[(IsRank2Tensor $x), (IsRank2Tensor $y)]>;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user