From be7b8874d09b6dfd7681c604908ce529392d6526 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Mon, 7 Dec 2020 15:08:47 -0800 Subject: [PATCH] Verify and Canonicalize static shaped BatchMatMul op to BatchMatMulV2 op BatchMatMul op requires batch dimensions to match while BatchMatMulV2 op requires batch dimensions to be broadcast compatible. So, if the batch dimensions of BatchMatMul are matching, it can be canonicalized to V2 op. PiperOrigin-RevId: 346189258 Change-Id: I4bf055041cec42157c1d464a4d732597f44409c9 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 4 ++ .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 54 ++++++++++++------- .../mlir/tensorflow/tests/canonicalize.mlir | 14 +++++ .../mlir/tensorflow/tests/tf-ops.mlir | 16 ++++++ .../tensorflow/tests/unroll-batch-matmul.mlir | 20 ------- .../tensorflow/transforms/canonicalize.td | 10 ++++ 6 files changed, 79 insertions(+), 39 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index b0cc6f97468..91f8d0ebd9c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -870,6 +870,10 @@ It is computed as: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; let hasCanonicalizer = 1; + + let verifier = [{ + return Verify(*this); + }]; } def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 8271e30f50e..8dd449c9340 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -157,19 +157,13 @@ void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, } //===----------------------------------------------------------------------===// -// BatchMatMulOp +// BatchMatMulV2Op & BatchMatMulOp //===----------------------------------------------------------------------===// -void BatchMatMulOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BatchMatMulV2Op -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(BatchMatMulV2Op op) { +template ::value>::type * = nullptr> +static LogicalResult Verify(OpT op) { if (!HasRankAtLeast(op.x(), 2)) { return op.emitOpError("requires lhs operand to have rank at least two"); } @@ -185,17 +179,34 @@ static LogicalResult Verify(BatchMatMulV2Op op) { ArrayRef x_shape = x_ty.getShape(); ArrayRef y_shape = y_ty.getShape(); - // Check broadcast compatibility if both input shapes are known. + llvm::SmallVector result_batch_shape; + llvm::ArrayRef x_batches = x_shape.drop_back(2); + llvm::ArrayRef y_batches = y_shape.drop_back(2); + + // Check compatibility of batch dimensions if both input shapes are known. + // BatchMatMul should have exactly the same batch dimensions and + // BatchMatMulV2 should have broadcastable batch dimensions. // // The last two dimensions are non-batch dimensions that don't need to // participate in batch dimension compatibility check. - - llvm::SmallVector result_batch_shape; - if (!OpTrait::util::getBroadcastedShape( - x_shape.drop_back(2), y_shape.drop_back(2), result_batch_shape)) - return op.emitOpError() - << "found incompatible broadcast batch dimensions for lhs shape " - << x_ty << " and rhs shape " << y_ty; + if (std::is_same()) { + for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) { + int64_t x_dim = std::get<0>(dim_pairs); + int64_t y_dim = std::get<1>(dim_pairs); + if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) && + x_dim != y_dim) { + return op.emitOpError() + << "found mismatching batch dimensions for lhs shape " << x_ty + << " and rhs shape " << y_ty; + } + } + } else { + if (!OpTrait::util::getBroadcastedShape(x_batches, y_batches, + result_batch_shape)) + return op.emitOpError() + << "found incompatible broadcast batch dimensions for lhs shape " + << x_ty << " and rhs shape " << y_ty; + } RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output()); if (!output_ty) return success(); @@ -245,6 +256,11 @@ static LogicalResult Verify(BatchMatMulV2Op op) { return success(); } +void BatchMatMulOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + void BatchMatMulV2Op::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index fa04ac04c9b..5f3e1e1759b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -16,6 +16,20 @@ func @tfAssertFalse(%arg0: tensor<1x1x6x2xf32>) { return } +// CHECK-LABEL: testBatchMatMulToV2 +func @testBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>) -> tensor<2x3x7xf32> { + // CHECK: tf.BatchMatMulV2 + %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32> + return %0: tensor<2x3x7xf32> +} + +// CHECK-LABEL: testDynamicBatchMatMulToV2 +func @testDynamicBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor) -> tensor<2x3x7xf32> { + // CHECK: tf.BatchMatMul + %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3x5xf32>, tensor) -> tensor<2x3x7xf32> + return %0: tensor<2x3x7xf32> +} + // CHECK-LABEL: testBatchMatMulToMatMul func @testBatchMatMulToMatMul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> tensor<2x2xf32> { %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 8c156e338cb..db7722f4fe9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -3530,6 +3530,22 @@ func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf // ----- +// Legal BatchMatMul op. +func @testBatchMatMul(%lhs: tensor<2x?x2x?x3x5xf32>, %rhs: tensor<2x2x?x?x5x7xf32>) { + %0 = "tf.BatchMatMul"(%lhs, %rhs) : (tensor<2x?x2x?x3x5xf32>, tensor<2x2x?x?x5x7xf32>) -> tensor<2x?x?x?x3x7xf32> + return +} + +// ----- + +// Mismatching batch dimensions. +func @testBatchMatMul(%lhs: tensor<1x3x5xf32>, %rhs: tensor<2x5x7xf32>) { + // expected-error @+1 {{found mismatching batch dimensions for lhs shape 'tensor<1x3x5xf32>' and rhs shape 'tensor<2x5x7xf32>'}} + %0 = "tf.BatchMatMul"(%lhs, %rhs) : (tensor<1x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32> +} + +// ----- + func @testBatchMatMulV2(%lhs: tensor, %rhs: tensor<10x10xf32>) { // expected-error @+1 {{requires lhs operand to have rank at least two}} %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor, tensor<10x10xf32>) -> tensor<10x10xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir index 7cf5f19523d..ae553f14ac9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir @@ -219,23 +219,3 @@ func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tenso // CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> // CHECK: return %[[v0]] : tensor<4x6xf32> } - -// ----- - -func @batchMatMulVectorLhsInputMatchFailure(%arg0: tensor<10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> { - %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32> - return %0 : tensor<10x20xf32> - - // CHECK-LABEL: batchMatMulVectorLhs - // CHECK: %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32> -} - -// ----- - -func @batchMatMulVectorRhsInputMatchFailure(%arg0: tensor<10x20xf32>, %arg1: tensor<10xf32>) -> tensor<10x20xf32> { - %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32> - return %0 : tensor<10x20xf32> - - // CHECK-LABEL: batchMatMulVectorRhs - // CHECK: %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32> -} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 39340dc9d6b..dcd8931d414 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -55,6 +55,16 @@ def AddV2OfNegRight : Pat<(TF_AddV2Op $arg0, (TF_NegOp $arg1)), // BatchMatMul op patterns. //===----------------------------------------------------------------------===// +// Static shaped operands in a legal BatchMatMul op will have matching batch +// dimensions and can be upgraded to the BatchMatMulV2 op. Canonicalizing +// dynamically shaped operands is not correct as that will execute ops that +// have non matching batch dimensions but are broadcastable which should fail +// with V1. +def BatchMatMulToV2 : + Pat<(TF_BatchMatMulOp AnyStaticShapeTensor:$x, AnyStaticShapeTensor:$y, + $adj_x, $adj_y), + (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y)>; + def BatchMatMulToMatMul : Pat<(TF_BatchMatMulOp $x, $y, $adj_x, $adj_y), (TF_MatMulOp $x, $y, $adj_x, $adj_y), [(IsRank2Tensor $x), (IsRank2Tensor $y)]>;