From 3091cf8ee54965eeda769b8fd93ed9266b47a127 Mon Sep 17 00:00:00 2001 From: Haitang Hu Date: Wed, 9 Sep 2020 10:38:51 -0700 Subject: [PATCH] Add verifications for BatchMatmulV2 op. 1. Checks the compatibility of the broadcasting dimensions. 2. Check the output rank and shape with input shapes. PiperOrigin-RevId: 330751859 Change-Id: I21cac292ddd1e23977ae989311868ad08de963ae --- .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 66 ++++++++++ .../mlir/tensorflow/tests/tf-ops.mlir | 119 ++++++++++++++++++ .../xla/tests/legalize-tf-BatchMatMulV2.mlir | 10 +- 3 files changed, 190 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 0bf11f60700..10922bb19a5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -176,6 +176,72 @@ static LogicalResult Verify(BatchMatMulV2Op op) { if (!HasRankAtLeast(op.y(), 2)) { return op.emitOpError("requires rhs operand to have rank at least two"); } + + RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x()); + RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y()); + + if (!x_ty || !y_ty) return success(); + + ArrayRef<int64_t> x_shape = x_ty.getShape(); + ArrayRef<int64_t> y_shape = y_ty.getShape(); + + // Check broadcast compatibility if both input shapes are known. + // + // The last two dimensions are non-batch dimensions that don't need to + // participate in batch dimension compatibility check. 
+ + llvm::SmallVector<int64_t, 10> result_batch_shape; + if (!OpTrait::util::getBroadcastedShape( + x_shape.drop_back(2), y_shape.drop_back(2), result_batch_shape)) + return op.emitOpError() + << "found incompatible broadcast batch dimensions for lhs shape " + << x_ty << " and rhs shape " << y_ty; + + RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output()); + if (!output_ty) return success(); + + int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank()); + if (output_ty.getRank() != expected_output_rank) + return op.emitOpError() + << "found invalid output rank, expected " << expected_output_rank + << " but got " << output_ty.getRank(); + + // Check output batch dim with potential broadcasting. + ArrayRef<int64_t> output_shape = output_ty.getShape(); + for (int i = 0; i < result_batch_shape.size(); ++i) { + if (output_shape[i] != ShapedType::kDynamicSize && + output_shape[i] != result_batch_shape[i]) + return op.emitOpError() + << "has mismatching input batch dimension " + << result_batch_shape[i] << " and output batch dimension " + << output_shape[i]; + } + + // Check output shape for non-batch dimension, following documentation below. + // https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul + int64_t x_row_dim = x_shape[x_shape.size() - 2]; + int64_t x_col_dim = x_shape[x_shape.size() - 1]; + int64_t y_row_dim = y_shape[y_shape.size() - 2]; + int64_t y_col_dim = y_shape[y_shape.size() - 1]; + int64_t out_row_dim = output_shape[output_shape.size() - 2]; + int64_t out_col_dim = output_shape[output_shape.size() - 1]; + + int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim; + int64_t expected_out_col_dim = op.adj_y() ? 
y_row_dim : y_col_dim; + + if (expected_out_row_dim != ShapedType::kDynamicSize && + out_row_dim != ShapedType::kDynamicSize && + out_row_dim != expected_out_row_dim) + return op.emitOpError() + << "found invalid output dimension on row, expected " + << expected_out_row_dim << " but got " << out_row_dim; + if (expected_out_col_dim != ShapedType::kDynamicSize && + out_col_dim != ShapedType::kDynamicSize && + out_col_dim != expected_out_col_dim) + return op.emitOpError() + << "found invalid output dimension on col, expected " + << expected_out_col_dim << " but got " << out_col_dim; + return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 16ec052bd71..57be5bdddd1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -3239,6 +3239,125 @@ func @testBatchMatMulV2(%lhs: tensor<10x10xf32>, %rhs: tensor) { // ----- +// CHECK-LABEL: func @testBatchMatMulV2NoBatchDimension +func @testBatchMatMulV2NoBatchDimension(%lhs: tensor<5x10xf32>, %rhs: tensor<10x10xf32>) -> (tensor<5x10xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<5x10xf32>, tensor<10x10xf32>) -> tensor<5x10xf32> + return %0 : tensor<5x10xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2ValidBroadcastingBatchDimension +func @testBatchMatMulV2ValidBroadcastingBatchDimension(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) -> (tensor<10x2x5x10xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x5x10xf32> + return %0 : tensor<10x2x5x10xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2ValidMultiBatchDimension +func @testBatchMatMulV2ValidMultiBatchDimension(%lhs: tensor<4x5x1x3x2xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x2x5xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x1x3x2xf32>, tensor<1x1x3x5xf32>) -> 
tensor<4x5x1x2x5xf32> + return %0 : tensor<4x5x1x2x5xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithHigherXRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10x10xf32>) { + // expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<10x2x5x10xf32>' and rhs shape 'tensor<10x10x10xf32>'}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10x10xf32>) -> tensor<10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithSameRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10x10x10xf32>) { + // expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<10x2x5x10xf32>' and rhs shape 'tensor<10x10x10x10xf32>'}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithHigherYRank(%lhs: tensor<2x5x10xf32>, %rhs: tensor<10x10x10x10xf32>) { + // expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<2x5x10xf32>' and rhs shape 'tensor<10x10x10x10xf32>'}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<2x5x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputBatchDimension(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<2x10x10xf32>) { + // expected-error @+1 {{has mismatching input batch dimension 2 and output batch dimension 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<2x10x10xf32>) -> tensor<10x3x10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x1x10x10xf32>) { + // expected-error @+1 {{found invalid output rank, expected 4 but got 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x1x10x10xf32>) -> tensor<10x5x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputRowDim(%lhs: 
tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) { + // expected-error @+1 {{found invalid output dimension on row, expected 5 but got 10}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x10x10xf32> +} + +// ----- + +func @testBatchMatMulV2AdjXInvalidOutputRowDim(%lhs: tensor<10x2x10x5xf32>, %rhs: tensor<10x10xf32>) { + // expected-error @+1 {{found invalid output dimension on row, expected 5 but got 10}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<10x2x10x5xf32>, tensor<10x10xf32>) -> tensor<10x2x10x10xf32> +} + +// ----- + +func @testBatchMatMulV2InvalidOutputColDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 10 but got 5}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x5x5xf32> +} + +// ----- + +func @testBatchMatMulV2AdjYInvalidOutputColDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<4x10xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 4 but got 10}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_y = true } : (tensor<10x2x5x10xf32>, tensor<4x10xf32>) -> tensor<10x2x5x10xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2PartiallyKnownInputBatchDim +func @testBatchMatMulV2PartiallyKnownInputBatchDim(%lhs: tensor<4x5x?x3x2xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x?x2x5xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x?x3x2xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x?x2x5xf32> + return %0 : tensor<4x5x?x2x5xf32> +} + +// ----- + +// CHECK-LABEL: func @testBatchMatMulV2PartiallyKnownMatmulDim +func @testBatchMatMulV2PartiallyKnownMatmulDim(%lhs: tensor<4x5x1x?x3xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x5xf32>) { + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<4x5x1x?x3xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x5xf32> + return %0 : tensor<4x5x1x?x5xf32> +} + +// ----- + +func 
@testBatchMatMulV2InvalidPartiallyKnownMatmulDim(%lhs: tensor<4x5x1x?x3xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x3xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 5 but got 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<4x5x1x?x3xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x3xf32> + return %0 : tensor<4x5x1x?x3xf32> +} + +// ----- + +func @testBatchMatMulV2AdjXInvalidPartiallyKnownMatmulDim(%lhs: tensor<4x5x1x3x?xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x3xf32>) { + // expected-error @+1 {{found invalid output dimension on col, expected 5 but got 3}} + %0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x1x3x?xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x3xf32> + return %0 : tensor<4x5x1x?x3xf32> +} + +// ----- + func @testDataFormatVecPermuteInvalid1dInput(%x: tensor<5xi32>) { // expected-error @+1 {{requires 1D input of size 4}} %0 = "tf.DataFormatVecPermute"(%x): (tensor<5xi32>) -> tensor<5xi32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir index cffb15022b0..5a07d9303f0 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir @@ -60,20 +60,20 @@ func @batchmatmulv2_dynamic(%arg0: tensor, %arg1: tensor) return %0 : tensor } -func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> { +func @batchmatmulv2_adj_real(%arg0: tensor<2x5xf32>, %arg1: tensor<4x2xf32>) -> tensor<5x4xf32> { // CHECK-LABEL: func @batchmatmulv2_adj_real // CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>, // CHECK-SAME: rhs_contracting_dimensions = dense<1> : 
tensor<1xi64>}} - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xf32>, tensor<4x2xf32>) -> tensor<5x4xf32> return %0 : tensor<5x4xf32> } -func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> { +func @batchmatmulv2_adj_complex(%arg0: tensor<2x5xcomplex<f32>>, %arg1: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> { // CHECK-LABEL: func @batchmatmulv2_adj_complex( -// CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex<f32>>, [[RHS:%.*]]: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> { +// CHECK-SAME: [[LHS:%.*]]: tensor<2x5xcomplex<f32>>, [[RHS:%.*]]: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> { // CHECK: [[LHSRE:%.*]] = "mhlo.real"([[LHS]]) // CHECK: [[LHSIM:%.*]] = "mhlo.imag"([[LHS]]) // CHECK: [[LHSIMNEG:%.*]] = "mhlo.negate"([[LHSIM]]) @@ -84,6 +84,6 @@ func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2 // CHECK: [[RHSCONJ:%.*]] = "mhlo.complex"([[RHSRE]], [[RHSIMNEG]]) // CHECK: shape.shape_of [[LHSCONJ]] // CHECK: shape.shape_of [[RHSCONJ]] - %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> + %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xcomplex<f32>>, tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> return %0 : tensor<5x4xcomplex<f32>> }