Verify and canonicalize static-shaped BatchMatMul op to BatchMatMulV2 op

The BatchMatMul op requires its batch dimensions to match exactly, while the BatchMatMulV2 op only requires them to be broadcast compatible. A BatchMatMul op whose batch dimensions are known to match can therefore be canonicalized to the V2 op.
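For illustration, the difference between the two ops' rules can be captured in a small standalone C++ sketch (the helper names and the kDynamic constant are hypothetical stand-ins for this sketch only; the real verifier uses ShapedType::isDynamic and OpTrait::util::getBroadcastedShape):

#include <cstdint>
#include <cstdlib>
#include <vector>

// Hypothetical stand-in for an unknown (dynamic) dimension, mirroring the
// role of ShapedType::isDynamic in the real verifier.
constexpr int64_t kDynamic = -1;

// BatchMatMul (V1) rule: every pair of statically known batch dimensions
// must be equal; dynamic dimensions are skipped.
bool BatchDimsMatch(const std::vector<int64_t>& x,
                    const std::vector<int64_t>& y) {
  if (x.size() != y.size()) return false;
  for (size_t i = 0; i < x.size(); ++i)
    if (x[i] != kDynamic && y[i] != kDynamic && x[i] != y[i]) return false;
  return true;
}

// BatchMatMulV2 rule: numpy-style broadcasting; statically known dimensions
// are compatible when equal or when either one is 1 (the shorter rank is
// padded with 1s on the left).
bool BatchDimsBroadcastable(std::vector<int64_t> x, std::vector<int64_t> y) {
  while (x.size() < y.size()) x.insert(x.begin(), 1);
  while (y.size() < x.size()) y.insert(y.begin(), 1);
  for (size_t i = 0; i < x.size(); ++i) {
    if (x[i] == kDynamic || y[i] == kDynamic) continue;
    if (x[i] != y[i] && x[i] != 1 && y[i] != 1) return false;
  }
  return true;
}

int main() {
  // Batch dims {1, 4} vs {3, 1}: broadcastable to {3, 4}, so legal for V2,
  // but not an exact match, so illegal for V1.
  std::vector<int64_t> a = {1, 4}, b = {3, 1};
  bool ok = !BatchDimsMatch(a, b) && BatchDimsBroadcastable(a, b);
  return ok ? EXIT_SUCCESS : EXIT_FAILURE;
}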

PiperOrigin-RevId: 346189258
Change-Id: I4bf055041cec42157c1d464a4d732597f44409c9
Authored by Smit Hinsu on 2020-12-07 15:08:47 -08:00; committed by TensorFlower Gardener
parent 0b66713efa
commit be7b8874d0
6 changed files with 79 additions and 39 deletions


@@ -870,6 +870,10 @@ It is computed as:
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let verifier = [{
return Verify(*this);
}];
}
def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {


@@ -157,19 +157,13 @@ void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
}
//===----------------------------------------------------------------------===//
// BatchMatMulOp
// BatchMatMulV2Op & BatchMatMulOp
//===----------------------------------------------------------------------===//
void BatchMatMulOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<BatchMatMulToMatMul>(context);
}
//===----------------------------------------------------------------------===//
// BatchMatMulV2Op
//===----------------------------------------------------------------------===//
static LogicalResult Verify(BatchMatMulV2Op op) {
template <typename OpT,
typename std::enable_if<llvm::is_one_of<
OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type * = nullptr>
static LogicalResult Verify(OpT op) {
if (!HasRankAtLeast(op.x(), 2)) {
return op.emitOpError("requires lhs operand to have rank at least two");
}
@@ -185,17 +179,34 @@ static LogicalResult Verify(BatchMatMulV2Op op) {
ArrayRef<int64_t> x_shape = x_ty.getShape();
ArrayRef<int64_t> y_shape = y_ty.getShape();
// Check broadcast compatibility if both input shapes are known.
llvm::SmallVector<int64_t, 4> result_batch_shape;
llvm::ArrayRef<int64_t> x_batches = x_shape.drop_back(2);
llvm::ArrayRef<int64_t> y_batches = y_shape.drop_back(2);
// Check compatibility of batch dimensions if both input shapes are known.
// BatchMatMul should have exactly the same batch dimensions and
// BatchMatMulV2 should have broadcastable batch dimensions.
//
// The last two dimensions are non-batch dimensions and do not participate
// in the batch dimension compatibility check.
llvm::SmallVector<int64_t, 4> result_batch_shape;
if (!OpTrait::util::getBroadcastedShape(
x_shape.drop_back(2), y_shape.drop_back(2), result_batch_shape))
return op.emitOpError()
<< "found incompatible broadcast batch dimensions for lhs shape "
<< x_ty << " and rhs shape " << y_ty;
if (std::is_same<OpT, BatchMatMulOp>()) {
for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) {
int64_t x_dim = std::get<0>(dim_pairs);
int64_t y_dim = std::get<1>(dim_pairs);
if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) &&
x_dim != y_dim) {
return op.emitOpError()
<< "found mismatching batch dimensions for lhs shape " << x_ty
<< " and rhs shape " << y_ty;
}
}
} else {
if (!OpTrait::util::getBroadcastedShape(x_batches, y_batches,
result_batch_shape))
return op.emitOpError()
<< "found incompatible broadcast batch dimensions for lhs shape "
<< x_ty << " and rhs shape " << y_ty;
}
RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
if (!output_ty) return success();
@@ -245,6 +256,11 @@ static LogicalResult Verify(BatchMatMulV2Op op) {
return success();
}
void BatchMatMulOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<BatchMatMulToV2>(context);
}
void BatchMatMulV2Op::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<BatchMatMulV2ToMatMul>(context);
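The shared Verify above is constrained with std::enable_if over llvm::is_one_of so that a single template serves both ops. A minimal self-contained C++17 sketch of that SFINAE guard (the op structs and the is_one_of alias are stand-ins, not the real MLIR/LLVM types):

#include <type_traits>

// Minimal stand-ins for the op classes; illustrative only.
struct BatchMatMulOp {};
struct BatchMatMulV2Op {};
struct SomeOtherOp {};

// Local equivalent of llvm::is_one_of for this sketch: true iff T is one of
// the listed types. Requires C++17 for std::disjunction.
template <typename T, typename... Ts>
using is_one_of = std::disjunction<std::is_same<T, Ts>...>;

// The defaulted non-type template parameter is only well-formed when OpT is
// one of the two BatchMatMul ops; any other type is removed from overload
// resolution at compile time (SFINAE).
template <typename OpT,
          typename std::enable_if<
              is_one_of<OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type* =
              nullptr>
bool Verify(OpT op) {
  return true;
}

int main() {
  Verify(BatchMatMulOp{});    // OK: template instantiates.
  Verify(BatchMatMulV2Op{});  // OK.
  // Verify(SomeOtherOp{});   // Error: no matching function to call.
}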


@@ -16,6 +16,20 @@ func @tfAssertFalse(%arg0: tensor<1x1x6x2xf32>) {
return
}
// CHECK-LABEL: testBatchMatMulToV2
func @testBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>) -> tensor<2x3x7xf32> {
// CHECK: tf.BatchMatMulV2
%0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32>
return %0: tensor<2x3x7xf32>
}
// CHECK-LABEL: testDynamicBatchMatMulToV2
func @testDynamicBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<?x5x7xf32>) -> tensor<2x3x7xf32> {
// CHECK: tf.BatchMatMul
%0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3x5xf32>, tensor<?x5x7xf32>) -> tensor<2x3x7xf32>
return %0: tensor<2x3x7xf32>
}
// CHECK-LABEL: testBatchMatMulToMatMul
func @testBatchMatMulToMatMul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> tensor<2x2xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32>


@@ -3530,6 +3530,22 @@ func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf
// -----
// Legal BatchMatMul op.
func @testBatchMatMul(%lhs: tensor<2x?x2x?x3x5xf32>, %rhs: tensor<2x2x?x?x5x7xf32>) {
%0 = "tf.BatchMatMul"(%lhs, %rhs) : (tensor<2x?x2x?x3x5xf32>, tensor<2x2x?x?x5x7xf32>) -> tensor<2x?x?x?x3x7xf32>
return
}
// -----
// Mismatching batch dimensions.
func @testBatchMatMul(%lhs: tensor<1x3x5xf32>, %rhs: tensor<2x5x7xf32>) {
// expected-error @+1 {{found mismatching batch dimensions for lhs shape 'tensor<1x3x5xf32>' and rhs shape 'tensor<2x5x7xf32>'}}
%0 = "tf.BatchMatMul"(%lhs, %rhs) : (tensor<1x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32>
}
// -----
func @testBatchMatMulV2(%lhs: tensor<f32>, %rhs: tensor<10x10xf32>) {
// expected-error @+1 {{requires lhs operand to have rank at least two}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<f32>, tensor<10x10xf32>) -> tensor<10x10xf32>


@@ -219,23 +219,3 @@ func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tenso
// CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
// CHECK: return %[[v0]] : tensor<4x6xf32>
}
// -----
func @batchMatMulVectorLhsInputMatchFailure(%arg0: tensor<10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
return %0 : tensor<10x20xf32>
// CHECK-LABEL: batchMatMulVectorLhs
// CHECK: %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
}
// -----
func @batchMatMulVectorRhsInputMatchFailure(%arg0: tensor<10x20xf32>, %arg1: tensor<10xf32>) -> tensor<10x20xf32> {
%0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32>
return %0 : tensor<10x20xf32>
// CHECK-LABEL: batchMatMulVectorRhs
// CHECK: %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32>
}


@@ -55,6 +55,16 @@ def AddV2OfNegRight : Pat<(TF_AddV2Op $arg0, (TF_NegOp $arg1)),
// BatchMatMul op patterns.
//===----------------------------------------------------------------------===//
// Static shaped operands in a legal BatchMatMul op are guaranteed to have
// matching batch dimensions, so the op can be upgraded to the BatchMatMulV2
// op. Canonicalizing dynamically shaped operands would not be correct: the
// rewritten op could execute with batch dimensions that are broadcastable
// but not equal, which V1 is required to reject (see the sketch after these
// patterns).
def BatchMatMulToV2 :
Pat<(TF_BatchMatMulOp AnyStaticShapeTensor:$x, AnyStaticShapeTensor:$y,
$adj_x, $adj_y),
(TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y)>;
def BatchMatMulToMatMul : Pat<(TF_BatchMatMulOp $x, $y, $adj_x, $adj_y),
(TF_MatMulOp $x, $y, $adj_x, $adj_y),
[(IsRank2Tensor $x), (IsRank2Tensor $y)]>;
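To see why BatchMatMulToV2 is restricted to AnyStaticShapeTensor operands, consider a minimal C++ sketch, assuming hypothetical runtime values for a dynamic batch dimension:

#include <cstdint>
#include <iostream>

int main() {
  // Suppose the lhs batch dimension is dynamic ('?') at compile time and
  // turns out to be 1 at runtime, while the rhs batch dimension is 2.
  int64_t lhs_batch = 1;
  int64_t rhs_batch = 2;

  // BatchMatMul (V1): exact match required, so this pair must be rejected.
  bool v1_ok = lhs_batch == rhs_batch;

  // BatchMatMulV2: 1 broadcasts against 2, so this pair is accepted.
  bool v2_ok = lhs_batch == rhs_batch || lhs_batch == 1 || rhs_batch == 1;

  // Prints "V1 ok: 0, V2 ok: 1": rewriting V1 to V2 would silently accept
  // inputs V1 must reject, so the pattern fires only on static shapes.
  std::cout << "V1 ok: " << v1_ok << ", V2 ok: " << v2_ok << "\n";
}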