diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index b0cc6f97468..91f8d0ebd9c 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -870,6 +870,10 @@ It is computed as:
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

   let hasCanonicalizer = 1;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
 }

 def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 8271e30f50e..8dd449c9340 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -157,19 +157,13 @@ void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 }

 //===----------------------------------------------------------------------===//
-// BatchMatMulOp
+// BatchMatMulV2Op & BatchMatMulOp
 //===----------------------------------------------------------------------===//

-void BatchMatMulOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<BatchMatMulToMatMul>(context);
-}
-
-//===----------------------------------------------------------------------===//
-// BatchMatMulV2Op
-//===----------------------------------------------------------------------===//
-
-static LogicalResult Verify(BatchMatMulV2Op op) {
+template <typename OpT,
+          typename std::enable_if<llvm::is_one_of<
+              OpT, BatchMatMulOp, BatchMatMulV2Op>::value>::type * = nullptr>
+static LogicalResult Verify(OpT op) {
   if (!HasRankAtLeast(op.x(), 2)) {
     return op.emitOpError("requires lhs operand to have rank at least two");
   }
@@ -185,17 +179,34 @@ static LogicalResult Verify(BatchMatMulV2Op op) {
   ArrayRef<int64_t> x_shape = x_ty.getShape();
   ArrayRef<int64_t> y_shape = y_ty.getShape();

-  // Check broadcast compatibility if both input shapes are known.
+  llvm::SmallVector<int64_t, 4> result_batch_shape;
+  llvm::ArrayRef<int64_t> x_batches = x_shape.drop_back(2);
+  llvm::ArrayRef<int64_t> y_batches = y_shape.drop_back(2);
+
+  // Check compatibility of batch dimensions if both input shapes are known.
+  // BatchMatMul should have exactly the same batch dimensions and
+  // BatchMatMulV2 should have broadcastable batch dimensions.
   //
   // The last two dimensions are non-batch dimensions that don't need to
   // participate in batch dimension compatibility check.
-
-  llvm::SmallVector<int64_t, 4> result_batch_shape;
-  if (!OpTrait::util::getBroadcastedShape(
-          x_shape.drop_back(2), y_shape.drop_back(2), result_batch_shape))
-    return op.emitOpError()
-           << "found incompatible broadcast batch dimensions for lhs shape "
-           << x_ty << " and rhs shape " << y_ty;
+  if (std::is_same<OpT, BatchMatMulOp>()) {
+    for (const auto &dim_pairs : llvm::zip(x_batches, y_batches)) {
+      int64_t x_dim = std::get<0>(dim_pairs);
+      int64_t y_dim = std::get<1>(dim_pairs);
+      if (!ShapedType::isDynamic(x_dim) && !ShapedType::isDynamic(y_dim) &&
+          x_dim != y_dim) {
+        return op.emitOpError()
+               << "found mismatching batch dimensions for lhs shape " << x_ty
+               << " and rhs shape " << y_ty;
+      }
+    }
+  } else {
+    if (!OpTrait::util::getBroadcastedShape(x_batches, y_batches,
+                                            result_batch_shape))
+      return op.emitOpError()
+             << "found incompatible broadcast batch dimensions for lhs shape "
+             << x_ty << " and rhs shape " << y_ty;
+  }

   RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
   if (!output_ty) return success();
@@ -245,6 +256,11 @@ static LogicalResult Verify(BatchMatMulV2Op op) {
   return success();
 }

+void BatchMatMulOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<BatchMatMulToMatMul, BatchMatMulToV2>(context);
+}
+
 void BatchMatMulV2Op::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<BatchMatMulV2ToMatMul>(context);
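The templated Verify above folds the V1 and V2 verifiers into one function: enable_if plus llvm::is_one_of restricts the template to the two BatchMatMul ops, and a std::is_same test selects the strict exact-match rule for V1 against the broadcast rule for V2. A minimal self-contained sketch of the same compile-time dispatch idiom follows; the names BatchDimsCompatible and kDynamicDim are illustrative stand-ins, not code from the patch.

    #include <iostream>
    #include <type_traits>

    struct BatchMatMul {};    // stand-in tag for the V1 op
    struct BatchMatMulV2 {};  // stand-in tag for the V2 op

    // Restrict the template to the two listed op tags, mirroring the
    // llvm::is_one_of guard in the verifier.
    template <typename OpT,
              typename std::enable_if<
                  std::is_same<OpT, BatchMatMul>::value ||
                  std::is_same<OpT, BatchMatMulV2>::value>::type * = nullptr>
    bool BatchDimsCompatible(long x_dim, long y_dim) {
      const long kDynamicDim = -1;  // stand-in for ShapedType::isDynamic
      if (x_dim == kDynamicDim || y_dim == kDynamicDim) return true;
      if (std::is_same<OpT, BatchMatMul>())
        return x_dim == y_dim;  // V1: batch dims must match exactly
      return x_dim == y_dim || x_dim == 1 || y_dim == 1;  // V2: must broadcast
    }

    int main() {
      std::cout << BatchDimsCompatible<BatchMatMul>(1, 2);    // 0: V1 rejects
      std::cout << BatchDimsCompatible<BatchMatMulV2>(1, 2);  // 1: V2 broadcasts
      std::cout << "\n";
    }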
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
index fa04ac04c9b..5f3e1e1759b 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
@@ -16,6 +16,20 @@ func @tfAssertFalse(%arg0: tensor<1x1x6x2xf32>) {
   return
 }

+// CHECK-LABEL: testBatchMatMulToV2
+func @testBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2x5x7xf32>) -> tensor<2x3x7xf32> {
+  // CHECK: tf.BatchMatMulV2
+  %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32>
+  return %0: tensor<2x3x7xf32>
+}
+
+// CHECK-LABEL: testDynamicBatchMatMulToV2
+func @testDynamicBatchMatMulToV2(%arg0: tensor<2x3x5xf32>, %arg1: tensor<?x5x7xf32>) -> tensor<2x3x7xf32> {
+  // CHECK: tf.BatchMatMul
+  %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3x5xf32>, tensor<?x5x7xf32>) -> tensor<2x3x7xf32>
+  return %0: tensor<2x3x7xf32>
+}
+
 // CHECK-LABEL: testBatchMatMulToMatMul
 func @testBatchMatMulToMatMul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> tensor<2x2xf32> {
   %0 = "tf.BatchMatMul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<2x3xf32>, tensor<3x2xf32>) -> tensor<2x2xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 8c156e338cb..db7722f4fe9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -3530,6 +3530,22 @@ func @testParseExampleV2RaggedMismatchedOutputLengths(%serialized: tensor<32x!tf

 // -----

+// Legal BatchMatMul op.
+func @testBatchMatMul(%lhs: tensor<2x?x2x?x3x5xf32>, %rhs: tensor<2x2x?x?x5x7xf32>) {
+  %0 = "tf.BatchMatMul"(%lhs, %rhs) : (tensor<2x?x2x?x3x5xf32>, tensor<2x2x?x?x5x7xf32>) -> tensor<2x?x?x?x3x7xf32>
+  return
+}
+
+// -----
+
+// Mismatching batch dimensions.
+func @testBatchMatMul(%lhs: tensor<1x3x5xf32>, %rhs: tensor<2x5x7xf32>) {
+  // expected-error @+1 {{found mismatching batch dimensions for lhs shape 'tensor<1x3x5xf32>' and rhs shape 'tensor<2x5x7xf32>'}}
+  %0 = "tf.BatchMatMul"(%lhs, %rhs) : (tensor<1x3x5xf32>, tensor<2x5x7xf32>) -> tensor<2x3x7xf32>
+}
+
+// -----
+
 func @testBatchMatMulV2(%lhs: tensor<f32>, %rhs: tensor<10x10xf32>) {
   // expected-error @+1 {{requires lhs operand to have rank at least two}}
   %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<f32>, tensor<10x10xf32>) -> tensor<10x10xf32>
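The two new tf-ops.mlir cases above pin down the V1 rule: a batch-dimension pair only triggers an error when both dimensions are static and unequal, which is why the mixed static/dynamic shapes in the legal test pass. The same rule applied to those exact batch shapes, as a tiny standalone sketch (the helper name and the -1 dynamic-dim sentinel are illustrative, not from the patch):

    #include <cstdio>
    #include <vector>

    // V1 batch-dim check: reject only when both dims are known (here >= 0,
    // with -1 marking a dynamic dim) and they differ.
    bool V1BatchDimsOk(const std::vector<long> &x, const std::vector<long> &y) {
      for (size_t i = 0; i < x.size() && i < y.size(); ++i)
        if (x[i] >= 0 && y[i] >= 0 && x[i] != y[i]) return false;
      return true;
    }

    int main() {
      // Batch dims of tensor<2x?x2x?x3x5xf32> vs tensor<2x2x?x?x5x7xf32>:
      std::printf("%d\n", V1BatchDimsOk({2, -1, 2, -1}, {2, 2, -1, -1}));  // 1
      // Batch dims of tensor<1x3x5xf32> vs tensor<2x5x7xf32>:
      std::printf("%d\n", V1BatchDimsOk({1}, {2}));  // 0: mismatch reported
    }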
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
index 7cf5f19523d..ae553f14ac9 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir
@@ -219,23 +219,3 @@ func @batchMatMulMatrix(%arg0: tensor<4x5xf32>, %arg1: tensor<5x6xf32>) -> tenso
   // CHECK: %[[v0:.*]] = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32>
   // CHECK: return %[[v0]] : tensor<4x6xf32>
 }
-
-// -----
-
-func @batchMatMulVectorLhsInputMatchFailure(%arg0: tensor<10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> {
-  %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
-  return %0 : tensor<10x20xf32>
-
-  // CHECK-LABEL: batchMatMulVectorLhs
-  // CHECK: %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
-}
-
-// -----
-
-func @batchMatMulVectorRhsInputMatchFailure(%arg0: tensor<10x20xf32>, %arg1: tensor<10xf32>) -> tensor<10x20xf32> {
-  %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32>
-  return %0 : tensor<10x20xf32>
-
-  // CHECK-LABEL: batchMatMulVectorRhs
-  // CHECK: %0 = "tf.BatchMatMul"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<10xf32>) -> tensor<10x20xf32>
-}
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
index 39340dc9d6b..dcd8931d414 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
@@ -55,6 +55,16 @@ def AddV2OfNegRight : Pat<(TF_AddV2Op $arg0, (TF_NegOp $arg1)),
 // BatchMatMul op patterns.
 //===----------------------------------------------------------------------===//

+// Static shaped operands in a legal BatchMatMul op will have matching batch
+// dimensions and can be upgraded to the BatchMatMulV2 op. Canonicalizing
+// dynamically shaped operands is not correct: V2 broadcasts batch dimensions
+// that are compatible but not equal, whereas such operands should fail
+// verification with V1.
+def BatchMatMulToV2 :
+  Pat<(TF_BatchMatMulOp AnyStaticShapeTensor:$x, AnyStaticShapeTensor:$y,
+       $adj_x, $adj_y),
+      (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y)>;
+
 def BatchMatMulToMatMul : Pat<(TF_BatchMatMulOp $x, $y, $adj_x, $adj_y),
                               (TF_MatMulOp $x, $y, $adj_x, $adj_y),
                               [(IsRank2Tensor $x), (IsRank2Tensor $y)]>;
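The BatchMatMulToV2 pattern is gated on AnyStaticShapeTensor for both operands, so by the time it fires the new verifier has already guaranteed exactly matching batch dimensions. For comparison, a hand-written C++ analogue of the same guarded rewrite might look like the sketch below; it assumes the MLIR rewrite-pattern API of this era, and the class name is hypothetical rather than part of the change.

    // Illustrative C++ analogue of the DRR BatchMatMulToV2 pattern; the
    // committed change uses the TableGen pattern above instead.
    // (Assumes the usual mlir/IR/PatternMatch.h and TF dialect headers.)
    struct BatchMatMulToV2Sketch : public OpRewritePattern<TF::BatchMatMulOp> {
      using OpRewritePattern<TF::BatchMatMulOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(TF::BatchMatMulOp op,
                                    PatternRewriter &rewriter) const override {
        // Mirror the AnyStaticShapeTensor constraint: rewrite only when both
        // operands have fully static shapes.
        auto x_ty = op.x().getType().dyn_cast<RankedTensorType>();
        auto y_ty = op.y().getType().dyn_cast<RankedTensorType>();
        if (!x_ty || !y_ty || !x_ty.hasStaticShape() || !y_ty.hasStaticShape())
          return failure();

        rewriter.replaceOpWithNewOp<TF::BatchMatMulV2Op>(
            op, op.getType(), op.x(), op.y(), op.adj_xAttr(), op.adj_yAttr());
        return success();
      }
    };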