Add verifications for BatchMatMulV2 op.
1. Checks the compatibility of the broadcasting batch dimensions. 2. Checks the output rank and shape against the input shapes. PiperOrigin-RevId: 330751859 Change-Id: I21cac292ddd1e23977ae989311868ad08de963ae
This commit is contained in:
parent
4550368272
commit
3091cf8ee5
@ -176,6 +176,72 @@ static LogicalResult Verify(BatchMatMulV2Op op) {
|
||||
if (!HasRankAtLeast(op.y(), 2)) {
|
||||
return op.emitOpError("requires rhs operand to have rank at least two");
|
||||
}
|
||||
|
||||
RankedTensorType x_ty = GetRankedTensorTypeForOperand(op.x());
|
||||
RankedTensorType y_ty = GetRankedTensorTypeForOperand(op.y());
|
||||
|
||||
if (!x_ty || !y_ty) return success();
|
||||
|
||||
ArrayRef<int64_t> x_shape = x_ty.getShape();
|
||||
ArrayRef<int64_t> y_shape = y_ty.getShape();
|
||||
|
||||
// Check broadcast compatibility if both input shapes are known.
|
||||
//
|
||||
// The last two dimensions are non-batch dimensions that don't need to
|
||||
// participate in batch dimension compatibility check.
|
||||
|
||||
llvm::SmallVector<int64_t, 4> result_batch_shape;
|
||||
if (!OpTrait::util::getBroadcastedShape(
|
||||
x_shape.drop_back(2), y_shape.drop_back(2), result_batch_shape))
|
||||
return op.emitOpError()
|
||||
<< "found incompatible broadcast batch dimensions for lhs shape "
|
||||
<< x_ty << " and rhs shape " << y_ty;
|
||||
|
||||
RankedTensorType output_ty = GetRankedTensorTypeForOperand(op.output());
|
||||
if (!output_ty) return success();
|
||||
|
||||
int64_t expected_output_rank = std::max(x_ty.getRank(), y_ty.getRank());
|
||||
if (output_ty.getRank() != expected_output_rank)
|
||||
return op.emitOpError()
|
||||
<< "found invalid output rank, expected " << expected_output_rank
|
||||
<< " but got " << output_ty.getRank();
|
||||
|
||||
// Check output batch dim with potential broadcasting.
|
||||
ArrayRef<int64_t> output_shape = output_ty.getShape();
|
||||
for (int i = 0; i < result_batch_shape.size(); ++i) {
|
||||
if (output_shape[i] != ShapedType::kDynamicSize &&
|
||||
output_shape[i] != result_batch_shape[i])
|
||||
return op.emitOpError()
|
||||
<< "has mismatching input batch dimension "
|
||||
<< result_batch_shape[i] << " and output batch dimension "
|
||||
<< output_shape[i];
|
||||
}
|
||||
|
||||
// Check output shape for non-batch dimension, following documentation below.
|
||||
// https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
|
||||
int64_t x_row_dim = x_shape[x_shape.size() - 2];
|
||||
int64_t x_col_dim = x_shape[x_shape.size() - 1];
|
||||
int64_t y_row_dim = y_shape[y_shape.size() - 2];
|
||||
int64_t y_col_dim = y_shape[y_shape.size() - 1];
|
||||
int64_t out_row_dim = output_shape[output_shape.size() - 2];
|
||||
int64_t out_col_dim = output_shape[output_shape.size() - 1];
|
||||
|
||||
int64_t expected_out_row_dim = op.adj_x() ? x_col_dim : x_row_dim;
|
||||
int64_t expected_out_col_dim = op.adj_y() ? y_row_dim : y_col_dim;
|
||||
|
||||
if (expected_out_row_dim != ShapedType::kDynamicSize &&
|
||||
out_row_dim != ShapedType::kDynamicSize &&
|
||||
out_row_dim != expected_out_row_dim)
|
||||
return op.emitOpError()
|
||||
<< "found invalid output dimension on row, expected "
|
||||
<< expected_out_row_dim << " but got " << out_row_dim;
|
||||
if (expected_out_col_dim != ShapedType::kDynamicSize &&
|
||||
out_col_dim != ShapedType::kDynamicSize &&
|
||||
out_col_dim != expected_out_col_dim)
|
||||
return op.emitOpError()
|
||||
<< "found invalid output dimension on col, expected "
|
||||
<< expected_out_col_dim << " but got " << out_col_dim;
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -3239,6 +3239,125 @@ func @testBatchMatMulV2(%lhs: tensor<10x10xf32>, %rhs: tensor<f32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: rank-2 x rank-2 operands with no batch dimensions verify cleanly.
// CHECK-LABEL: func @testBatchMatMulV2NoBatchDimension
|
||||
func @testBatchMatMulV2NoBatchDimension(%lhs: tensor<5x10xf32>, %rhs: tensor<10x10xf32>) -> (tensor<5x10xf32>) {
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<5x10xf32>, tensor<10x10xf32>) -> tensor<5x10xf32>
|
||||
return %0 : tensor<5x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: lhs batch dims [10, 2] broadcast against a batchless (rank-2) rhs.
// CHECK-LABEL: func @testBatchMatMulV2ValidBroadcastingBatchDimension
|
||||
func @testBatchMatMulV2ValidBroadcastingBatchDimension(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) -> (tensor<10x2x5x10xf32>) {
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x5x10xf32>
|
||||
return %0 : tensor<10x2x5x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: multi-dim batch broadcast ([4, 5, 1] vs [1, 1]); adj_x = true makes
// the output row dim come from the lhs column dim (3x2 contributes rows of 2).
// CHECK-LABEL: func @testBatchMatMulV2ValidMultiBatchDimension
|
||||
func @testBatchMatMulV2ValidMultiBatchDimension(%lhs: tensor<4x5x1x3x2xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x2x5xf32>) {
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x1x3x2xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x2x5xf32>
|
||||
return %0 : tensor<4x5x1x2x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: lhs has the higher rank; batch dims [10, 2] cannot broadcast with [10].
func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithHigherXRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10x10xf32>) {
|
||||
// expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<10x2x5x10xf32>' and rhs shape 'tensor<10x10x10xf32>'}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10x10xf32>) -> tensor<10x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: equal-rank operands with mismatched batch dims ([10, 2] vs [10, 10]).
func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithSameRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10x10x10xf32>) {
|
||||
// expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<10x2x5x10xf32>' and rhs shape 'tensor<10x10x10x10xf32>'}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: rhs has the higher rank; batch dims [2] vs [10, 10] are incompatible.
func @testBatchMatMulV2InvalidBroadcastingBatchDimensionWithHigherYRank(%lhs: tensor<2x5x10xf32>, %rhs: tensor<10x10x10x10xf32>) {
|
||||
// expected-error @+1 {{found incompatible broadcast batch dimensions for lhs shape 'tensor<2x5x10xf32>' and rhs shape 'tensor<10x10x10x10xf32>'}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<2x5x10xf32>, tensor<10x10x10x10xf32>) -> tensor<10x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: broadcast batch shape is [10, 2] but the output declares batch dim 3.
func @testBatchMatMulV2InvalidOutputBatchDimension(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<2x10x10xf32>) {
|
||||
// expected-error @+1 {{has mismatching input batch dimension 2 and output batch dimension 3}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<2x10x10xf32>) -> tensor<10x3x10x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: output rank must equal max(lhs rank, rhs rank) = 4, but output is rank 3.
func @testBatchMatMulV2InvalidOutputRank(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x1x10x10xf32>) {
|
||||
// expected-error @+1 {{found invalid output rank, expected 4 but got 3}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x1x10x10xf32>) -> tensor<10x5x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: output row dim should come from the lhs row dim (5), not 10.
func @testBatchMatMulV2InvalidOutputRowDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) {
|
||||
// expected-error @+1 {{found invalid output dimension on row, expected 5 but got 10}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x10x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: with adj_x = true the output row dim comes from the lhs column dim (5), not 10.
func @testBatchMatMulV2AdjXInvalidOutputRowDim(%lhs: tensor<10x2x10x5xf32>, %rhs: tensor<10x10xf32>) {
|
||||
// expected-error @+1 {{found invalid output dimension on row, expected 5 but got 10}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<10x2x10x5xf32>, tensor<10x10xf32>) -> tensor<10x2x10x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: output col dim should come from the rhs column dim (10), not 5.
func @testBatchMatMulV2InvalidOutputColDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<10x10xf32>) {
|
||||
// expected-error @+1 {{found invalid output dimension on col, expected 10 but got 5}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<10x2x5x10xf32>, tensor<10x10xf32>) -> tensor<10x2x5x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: with adj_y = true the output col dim comes from the rhs row dim (4), not 10.
func @testBatchMatMulV2AdjYInvalidOutputColDim(%lhs: tensor<10x2x5x10xf32>, %rhs: tensor<4x10xf32>) {
|
||||
// expected-error @+1 {{found invalid output dimension on col, expected 4 but got 10}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_y = true } : (tensor<10x2x5x10xf32>, tensor<4x10xf32>) -> tensor<10x2x5x10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: a dynamic (?) batch dimension is accepted by the batch-compatibility check.
// CHECK-LABEL: func @testBatchMatMulV2PartiallyKnownInputBatchDim
|
||||
func @testBatchMatMulV2PartiallyKnownInputBatchDim(%lhs: tensor<4x5x?x3x2xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x?x2x5xf32>) {
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x?x3x2xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x?x2x5xf32>
|
||||
return %0 : tensor<4x5x?x2x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Positive test: a dynamic (?) row dimension in lhs and output passes the row/col checks.
// CHECK-LABEL: func @testBatchMatMulV2PartiallyKnownMatmulDim
|
||||
func @testBatchMatMulV2PartiallyKnownMatmulDim(%lhs: tensor<4x5x1x?x3xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x5xf32>) {
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<4x5x1x?x3xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x5xf32>
|
||||
return %0 : tensor<4x5x1x?x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: even with a dynamic row dim, the static output col dim (3) must match
// the rhs column dim (5).
func @testBatchMatMulV2InvalidPartiallyKnownMatmulDim(%lhs: tensor<4x5x1x?x3xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x3xf32>) {
|
||||
// expected-error @+1 {{found invalid output dimension on col, expected 5 but got 3}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) : (tensor<4x5x1x?x3xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x3xf32>
|
||||
return %0 : tensor<4x5x1x?x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Negative test: with adj_x = true and a dynamic lhs column dim, the static output
// col dim (3) must still match the rhs column dim (5).
func @testBatchMatMulV2AdjXInvalidPartiallyKnownMatmulDim(%lhs: tensor<4x5x1x3x?xf32>, %rhs: tensor<1x1x3x5xf32>) -> (tensor<4x5x1x?x3xf32>) {
|
||||
// expected-error @+1 {{found invalid output dimension on col, expected 5 but got 3}}
|
||||
%0 = "tf.BatchMatMulV2"(%lhs, %rhs) { adj_x = true } : (tensor<4x5x1x3x?xf32>, tensor<1x1x3x5xf32>) -> tensor<4x5x1x?x3xf32>
|
||||
return %0 : tensor<4x5x1x?x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testDataFormatVecPermuteInvalid1dInput(%x: tensor<5xi32>) {
|
||||
// expected-error @+1 {{requires 1D input of size 4}}
|
||||
%0 = "tf.DataFormatVecPermute"(%x): (tensor<5xi32>) -> tensor<5xi32>
|
||||
|
@ -60,20 +60,20 @@ func @batchmatmulv2_dynamic(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
func @batchmatmulv2_adj_real(%arg0: tensor<5x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<5x4xf32> {
|
||||
func @batchmatmulv2_adj_real(%arg0: tensor<2x5xf32>, %arg1: tensor<4x2xf32>) -> tensor<5x4xf32> {
|
||||
// CHECK-LABEL: func @batchmatmulv2_adj_real
|
||||
// CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = {
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<0> : tensor<1xi64>,
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>,
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<1> : tensor<1xi64>}}
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xf32>, tensor<2x4xf32>) -> tensor<5x4xf32>
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xf32>, tensor<4x2xf32>) -> tensor<5x4xf32>
|
||||
return %0 : tensor<5x4xf32>
|
||||
}
|
||||
|
||||
func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
|
||||
func @batchmatmulv2_adj_complex(%arg0: tensor<2x5xcomplex<f32>>, %arg1: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
|
||||
// CHECK-LABEL: func @batchmatmulv2_adj_complex(
|
||||
// CHECK-SAME: [[LHS:%.*]]: tensor<5x2xcomplex<f32>>, [[RHS:%.*]]: tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
|
||||
// CHECK-SAME: [[LHS:%.*]]: tensor<2x5xcomplex<f32>>, [[RHS:%.*]]: tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>> {
|
||||
// CHECK: [[LHSRE:%.*]] = "mhlo.real"([[LHS]])
|
||||
// CHECK: [[LHSIM:%.*]] = "mhlo.imag"([[LHS]])
|
||||
// CHECK: [[LHSIMNEG:%.*]] = "mhlo.negate"([[LHSIM]])
|
||||
@ -84,6 +84,6 @@ func @batchmatmulv2_adj_complex(%arg0: tensor<5x2xcomplex<f32>>, %arg1: tensor<2
|
||||
// CHECK: [[RHSCONJ:%.*]] = "mhlo.complex"([[RHSRE]], [[RHSIMNEG]])
|
||||
// CHECK: shape.shape_of [[LHSCONJ]]
|
||||
// CHECK: shape.shape_of [[RHSCONJ]]
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<5x2xcomplex<f32>>, tensor<2x4xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
|
||||
%0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = true, adj_y = true, device = ""} : (tensor<2x5xcomplex<f32>>, tensor<4x2xcomplex<f32>>) -> tensor<5x4xcomplex<f32>>
|
||||
return %0 : tensor<5x4xcomplex<f32>>
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user