Add verifications for BatchMatmulV2 op.

1. Checks the compatibility of the broadcasting dimensions.
2. Check the output rank and shape with input shapes.

PiperOrigin-RevId: 330751859
Change-Id: I21cac292ddd1e23977ae989311868ad08de963ae
This commit is contained in:
Haitang Hu 2020-09-09 10:38:51 -07:00 committed by TensorFlower Gardener
parent 4550368272
commit 3091cf8ee5
3 changed files with 190 additions and 5 deletions

View File

@ -176,6 +176,72 @@ static LogicalResult Verify(BatchMatMulV2Op op) {
if (!HasRankAtLeast(op.y(), 2)) {
return op.emitOpError("requires rhs operand to have rank at least two");
}
RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x());
RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y());
if (!x_ty || !y_ty) return success();
ArrayRef<int64_t> x_shape = x_ty.getShape();
ArrayRef<int64_t> y_shape = y_ty.getShape();
// Check broadcast compatibility if both input shapes are known.
//
// The last two dimensions are non-batch dimensions that don't need to
// participate in batch dimension compatibility check.
llvm::SmallVector<int64_t, 4> result_batch_shape;
if (!OpTrait::util::getBroadcastedShape(
x_shape.drop_back(2), y_shape.drop_back(2), result_batch_shape))
return op.emitOpError()
<< "found incompatible broadcast batch dimensions for lhs shape "
<< x_ty << " and rhs shape " << y_ty;
RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
if (!output_ty) return success();
int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank());
if (output_ty.getRank() != expected_output_rank)
return op.emitOpError()
<< "found invalid output rank, expected " << expected_output_rank
<< " but got " << output_ty.getRank();
// Check output batch dim with potential broadcasting.
ArrayRef<int64_t> output_shape = output_ty.getShape();
for (int i = 0; i < result_batch_shape.size(); ++i) {
if (output_shape[i] != ShapedType::kDynamicSize &&
output_shape[i] != result_batch_shape[i])
return op.emitOpError()
<< "has mismatching input batch dimension "
<< result_batch_shape[i] << " and output batch dimension "
<< output_shape[i];
}
// Check output shape for non-batch dimension, following documentation below.
// https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
int64_t x_row_dim = x_shape[x_shape.size() - 2];
int64_t x_col_dim = x_shape[x_shape.size() - 1];
int64_t y_row_dim = y_shape[y_shape.size() - 2];
int64_t y_col_dim = y_shape[y_shape.size() - 1];
int64_t out_row_dim = output_shape[output_shape.size() - 2];
int64_t out_col_dim = output_shape[output_shape.size() - 1];
int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim;
int64_t expected_out_col_dim = op.adj_y() ? y_row_dim : y_col_dim;
if (expected_out_row_dim != ShapedType::kDynamicSize &&
out_row_dim != ShapedType::kDynamicSize &&
out_row_dim != expected_out_row_dim)
return op.emitOpError()
<< "found invalid output dimension on row, expected "
<< expected_out_row_dim << " but got " << out_row_dim;
if (expected_out_col_dim != ShapedType::kDynamicSize &&
out_col_dim != ShapedType::kDynamicSize &&
out_col_dim != expected_out_col_dim)
return op.emitOpError()
<< "found invalid output dimension on col, expected "
<< expected_out_col_dim << " but got " << out_col_dim;
return success();
}

View File

@ -3239,6 +3239,125 @@ func @testBatchMatMulV2(%lhs: tensor<10x10xf32>, %rhs: tensor<f32>) {
// -----
// CHECK-LABEL: func @testBatchMatMulV2NoBatchDimension
func @testBatchMatMulV2NoBatchDimension(%lhs: tensor<5x10xf32>, %rhs: tensor<10x10xf32>) -> (tensor<5x10xf32>) {
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<5x10xf32>, tensor<10x10xf32>) -> tensor<5x10xf32>
return %0 : tensor<5x10xf32>
}
// -----
// CHECK-LABEL: func @testBatchMatMulV2ValidBroadcastingBatchDimension
func @testBatchMatMulV2ValidBroadcastingBatchDimension(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) -> (tensor<10x2x5x10xf32>) {
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x5x10xf32>
return %0 : tensor<10x2x5x10xf32>
}
// -----
// CHECK-LABEL: func @testBatchMatMulV2ValidMultiBatchDimension
func @testBatchMatMulV2ValidMultiBatchDimension(%lhs: tensor<4x5x1x3x2xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x2x5xf32>) {
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x1x3x2xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x2x5xf32>
return %0 : tensor<4x5x1x2x5xf32>
}
// -----
func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithHigherXRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10x10xf32>) {
// expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<10x2x5x10xf32>' and rhs shape 'tensor<10x10x10xf32>'}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10x10xf32>) -> tensor<10x10xf32>
}
// -----
func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithSameRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10x10x10xf32>) {
// expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<10x2x5x10xf32>' and rhs shape 'tensor<10x10x10x10xf32>'}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32>
}
// -----
func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithHigherYRank(%lhs: tensor<2x5x10xf32>, %rhs: tensor<10x10x10x10xf32>) {
// expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<2x5x10xf32>' and rhs shape 'tensor<10x10x10x10xf32>'}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<2x5x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32>
}
// -----
func @testBatchMatMulV2InvalidOutputBatchDimension(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<2x10x10xf32>) {
// expected-error @+1 {{has mismatching input batch dimension 2 and output batch dimension 3}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<2x10x10xf32>) -> tensor<10x3x10x10xf32>
}
// -----
func @testBatchMatMulV2InvalidOutputRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x1x10x10xf32>) {
// expected-error @+1 {{found invalid output rank, expected 4 but got 3}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x1x10x10xf32>) -> tensor<10x5x10xf32>
}
// -----
func @testBatchMatMulV2InvalidOutputRowDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) {
// expected-error @+1 {{found invalid output dimension on row, expected 5 but got 10}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x10x10xf32>
}
// -----
func @testBatchMatMulV2AdjXInvalidOutputRowDim(%lhs: tensor<10x2x10x5xf32>, %rhs: tensor<10x10xf32>) {
// expected-error @+1 {{found invalid output dimension on row, expected 5 but got 10}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<10x2x10x5xf32>, tensor<10x10xf32>) -> tensor<10x2x10x10xf32>
}
// -----
func @testBatchMatMulV2InvalidOutputColDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) {
// expected-error @+1 {{found invalid output dimension on col, expected 10 but got 5}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x5x5xf32>
}
// -----
func @testBatchMatMulV2AdjYInvalidOutputColDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<4x10xf32>) {
// expected-error @+1 {{found invalid output dimension on col, expected 4 but got 10}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_y = true } : (tensor<10x2x5x10xf32>, tensor<4x10xf32>) -> tensor<10x2x5x10xf32>
}
// -----
// CHECK-LABEL: func @testBatchMatMulV2PartiallyKnownInputBatchDim
func @testBatchMatMulV2PartiallyKnownInputBatchDim(%lhs: tensor<4x5x?x3x2xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x?x2x5xf32>) {
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x?x3x2xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x?x2x5xf32>
return %0 : tensor<4x5x?x2x5xf32>
}
// -----
// CHECK-LABEL: func @testBatchMatMulV2PartiallyKnownMatmulDim
func @testBatchMatMulV2PartiallyKnownMatmulDim(%lhs: tensor<4x5x1x?x3xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x5xf32>) {
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<4x5x1x?x3xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x5xf32>
return %0 : tensor<4x5x1x?x5xf32>
}
// -----
func @testBatchMatMulV2InvalidPartiallyKnownMatmulDim(%lhs: tensor<4x5x1x?x3xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x3xf32>) {
// expected-error @+1 {{found invalid output dimension on col, expected 5 but got 3}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<4x5x1x?x3xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x3xf32>
return %0 : tensor<4x5x1x?x3xf32>
}
// -----
func @testBatchMatMulV2AdjXInvalidPartiallyKnownMatmulDim(%lhs: tensor<4x5x1x3x?xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x3xf32>) {
// expected-error @+1 {{found invalid output dimension on col, expected 5 but got 3}}
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x1x3x?xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x3xf32>
return %0 : tensor<4x5x1x?x3xf32>
}
// -----
func @testDataFormatVecPermuteInvalid1dInput(%x: tensor<5xi32>) {
// expected-error @+1 {{requires 1D input of size 4}}
%0 = "tf.DataFormatVecPermute"(%x): (tensor<5xi32>) -> tensor<5xi32>

View File

@ -60,20 +60,20 @@ func @batchmatmulv2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
return %0 : tensor<?x?x?xf32>
}
func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
func @batchmatmulv2_adj_real(%arg0: tensor<2x5xf32>, %arg1: tensor<4x2xf32>) -> tensor<5x4xf32> {
// CHECK-LABEL: func @batchmatmulv2_adj_real
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>,
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}}
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xf32>, tensor<4x2xf32>) -> tensor<5x4xf32>
return %0 : tensor<5x4xf32>
}
func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
func @batchmatmulv2_adj_complex(%arg0: tensor<2x5xcomplex<f32>>, %arg1: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
// CHECK-LABEL: func @batchmatmulv2_adj_complex(
// CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex<f32>>, [[RHS:%.*]]: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
// CHECK-SAME: [[LHS:%.*]]: tensor<2x5xcomplex<f32>>, [[RHS:%.*]]: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
// CHECK: [[LHSRE:%.*]] = "mhlo.real"([[LHS]])
// CHECK: [[LHSIM:%.*]] = "mhlo.imag"([[LHS]])
// CHECK: [[LHSIMNEG:%.*]] = "mhlo.negate"([[LHSIM]])
@ -84,6 +84,6 @@ func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2
// CHECK: [[RHSCONJ:%.*]] = "mhlo.complex"([[RHSRE]], [[RHSIMNEG]])
// CHECK: shape.shape_of [[LHSCONJ]]
// CHECK: shape.shape_of [[RHSCONJ]]
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xcomplex<f32>>, tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
return %0 : tensor<5x4xcomplex<f32>>
}