Add verifier for tf.BatchToSpace.
The verifier checks the input, output, and crops operands to determine whether the op is valid. If the contents of crops can be determined at compile time, their values are used directly when checking the input and output shapes. PiperOrigin-RevId: 316691063 Change-Id: Idd495d2102604e267cdaa5f45a21c0bfa3dcbcb0
This commit is contained in:
parent
878ac5ae83
commit
8a08194685
@ -820,6 +820,10 @@ followed by cropping along the `height` and `width` dimensions.
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
|
||||
|
@ -695,6 +695,149 @@ void BatchMatMulV2Op::getCanonicalizationPatterns(
|
||||
results.insert<BatchMatMulV2ToMatMul>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BatchToSpaceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies tf.BatchToSpace: checks the ranks/shapes of the `input`, `crops`,
// and `output` operands against the op's `block_size` attribute. When `crops`
// is a compile-time constant, its values are folded into the spatial-dimension
// checks; otherwise only an upper-bound check on the spatial dims is possible.
// Dynamic dimensions are skipped rather than rejected.
static LogicalResult Verify(BatchToSpaceOp op) {
  // Op already has a constraint that block_size >= 2.
  int64_t block_size = op.block_size().getSExtValue();

  // Start with a fully dynamic 4D shape; refined below if the input is ranked.
  llvm::SmallVector<int64_t, 4> input_shape(4, ShapedType::kDynamicSize);
  auto input_type = op.input().getType().cast<TensorType>();
  if (input_type.hasRank()) {
    if (input_type.getRank() != 4)
      return op.emitOpError()
             << "requires input to be a 4D tensor, but got " << input_type;

    int64_t input_batch = input_type.getDimSize(0);
    // The batch dimension is redistributed over block_size * block_size
    // spatial positions, so it must be evenly divisible by that product.
    if (input_batch != ShapedType::kDynamicSize &&
        input_batch % (block_size * block_size) != 0) {
      return op.emitOpError()
             << "requires input batch (dimension 0) to be evenly divisible "
                "by (block_size * block_size), but got input batch "
             << input_batch << " and block_size " << block_size;
    }

    input_shape.assign(input_type.getShape().begin(),
                       input_type.getShape().end());
  }

  // crops must be a 2x2 tensor; dynamic dims are accepted, only a static
  // mismatch is an error.
  auto crops_type = op.crops().getType().cast<TensorType>();
  if (crops_type.hasRank()) {
    if (crops_type.getRank() != 2)
      return op.emitOpError()
             << "requires crops to be a 2D tensor, but got " << crops_type;

    // Returns true when dimension `dim` is either dynamic or exactly `size`.
    auto dim_of_size = [&](int64_t dim, int64_t size) {
      if (crops_type.isDynamicDim(dim)) return true;
      return crops_type.getDimSize(dim) == size;
    };
    if (!dim_of_size(0, 2) || !dim_of_size(1, 2))
      return op.emitOpError()
             << "requires crops to be a tensor<2x2>, but got " << crops_type;
  }

  DenseIntElementsAttr crops_attr;
  // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]],
  // and flattened as [crop_top, crop_bottom, crop_left, crop_right]
  llvm::SmallVector<int64_t, 4> crops_values;
  if (matchPattern(op.crops(), m_Constant(&crops_attr))) {
    assert(crops_attr.getNumElements() == 4 &&
           "tf.BatchToSpace crops must have 4 elements");

    auto crops_range = crops_attr.getIntValues();
    for (const auto &crops_value : crops_range) {
      int64_t crops_value_int = crops_value.getSExtValue();
      if (crops_value_int < 0)
        return op.emitOpError()
               << "requires all crop values to be nonnegative, but got "
               << crops_attr;

      crops_values.push_back(crops_value_int);
    }
  }

  auto output_type = op.output().getType().cast<TensorType>();
  if (output_type.hasRank()) {
    if (output_type.getRank() != 4)
      return op.emitOpError()
             << "requires output to be a 4D tensor, but got " << output_type;

    // Only compare a pair of dimensions when both are statically known.
    auto static_dims = [](int64_t dim_a, int64_t dim_b) {
      return dim_a != ShapedType::kDynamicSize &&
             dim_b != ShapedType::kDynamicSize;
    };

    auto output_shape = output_type.getShape();

    // output batch = input batch / (block_size * block_size).
    int64_t input_batch = input_shape[0];
    int64_t output_batch = output_shape[0];
    if (static_dims(input_batch, output_batch) &&
        (output_batch * block_size * block_size) != input_batch)
      return op.emitOpError()
             << "requires output batch (dimension 0) to be equal to input "
                "batch (dimension 0) / (block_size * block_size), but got "
                "output batch "
             << output_batch << ", input batch " << input_batch
             << ", and block_size " << block_size;

    // Checks one spatial dimension (height or width): exact equality when
    // crops are known constants, otherwise just the upper bound
    // input_dim * block_size (crops can be at minimum 0).
    auto check_spatial_dim = [&](int64_t spatial_dim_index,
                                 llvm::StringRef dim_name,
                                 llvm::StringRef crop_a_name,
                                 llvm::StringRef crop_b_name) -> LogicalResult {
      int64_t input_dim = input_shape[spatial_dim_index];
      int64_t output_dim = output_shape[spatial_dim_index];
      if (!static_dims(input_dim, output_dim)) return success();

      int64_t input_dim_pad = input_dim * block_size;
      // If crops are unknown, the maximum output spatial dim size is input
      // spatial dim size * block_size, as crops can be minimum 0.
      if (crops_values.empty() && output_dim > input_dim * block_size)
        return op.emitOpError()
               << "requires output " << dim_name << " (dimension "
               << spatial_dim_index << ") to be less than or equal to input "
               << dim_name << " (dimension " << spatial_dim_index
               << ") * block_size, but got output " << dim_name << " "
               << output_dim << ", input " << dim_name << " " << input_dim
               << ", and block_size " << block_size;

      if (!crops_values.empty()) {
        // output spatial dim = input spatial dim * block_size - crops.
        // Spatial dims are 1 (height) and 2 (width); their crop pairs live at
        // flattened indices [0,1] and [2,3] respectively.
        int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)];
        int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1];
        if (output_dim != input_dim_pad - crop_a - crop_b)
          return op.emitOpError()
                 << "requires output " << dim_name << " (dimension "
                 << spatial_dim_index << ") to be equal to input " << dim_name
                 << " (dimension " << spatial_dim_index << ") * block_size - "
                 << crop_a_name << " - " << crop_b_name << ", but got output "
                 << dim_name << " " << output_dim << ", input " << dim_name
                 << " " << input_dim << ", " << crop_a_name << " " << crop_a
                 << ", " << crop_b_name << " " << crop_b << ", and block_size "
                 << block_size;
      }

      return success();
    };

    if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) ||
        failed(check_spatial_dim(2, "width", "crop_left", "crop_right")))
      return failure();

    int64_t input_depth = input_shape[3];
    int64_t output_depth = output_shape[3];
    if (static_dims(input_depth, output_depth) && output_depth != input_depth)
      return op.emitOpError()
             << "requires output depth (dimension 3) to be equal to input "
                "depth (dimension 3), but got output depth "
             << output_depth << " and input depth " << input_depth;
  }

  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BiasAddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2872,4 +2872,141 @@ func @testSendTPUEmbeddingGradients(%x: tensor<512x256xf32>) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// tf.BatchToSpace
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// Valid: fully unranked input and crops — the verifier has nothing to check.
func @testBatchToSpaceDynamic(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) {
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32>
  return
}
|
||||
|
||||
// Valid: input is rank 4 with all-dynamic dims, so only the rank check fires.
func @testBatchToSpaceRankedInput(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<*xi32>) {
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<?x?x?x?xf32>, tensor<*xi32>) -> tensor<*xf32>
  return
}
|
||||
|
||||
// Valid: crops is rank 2 with dynamic dims, which the verifier accepts.
func @testBatchToSpaceRankedCrops(%arg0: tensor<*xf32>, %arg1: tensor<?x?xi32>) {
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<?x?xi32>) -> tensor<*xf32>
  return
}
|
||||
|
||||
// Valid: output is rank 4 with all-dynamic dims, so only the rank check fires.
func @testBatchToSpaceRankedOutput(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) {
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<?x?x?x?xf32>
  return
}
|
||||
|
||||
// Valid fully-static case: batch 36/(3*3)=4, height 8*3-1-2=21,
// width 8*3-3-4=17, depth unchanged at 8.
func @testBatchToSpaceStatic(%arg0: tensor<36x8x8x8xf32>) {
  %crops = "tf.Const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
  %0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 3 : i64} : (tensor<36x8x8x8xf32>, tensor<2x2xi32>) -> tensor<4x21x17x8xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: input is rank 1, not the required rank 4.
func @testBatchToSpaceInvalidInputRank(%arg0: tensor<8xf32>, %arg1: tensor<*xi32>) {
  // expected-error @+1 {{'tf.BatchToSpace' op requires input to be a 4D tensor, but got 'tensor<8xf32>'}}
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<8xf32>, tensor<*xi32>) -> tensor<*xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: input batch 2 is not divisible by block_size^2 = 4.
func @testBatchToSpaceInvalidInputBatch(%arg0: tensor<2x4x6x8xf32>, %arg1: tensor<*xi32>) {
  // expected-error @+1 {{'tf.BatchToSpace' op requires input batch (dimension 0) to be evenly divisible by (block_size * block_size), but got input batch 2 and block_size 2}}
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<2x4x6x8xf32>, tensor<*xi32>) -> tensor<*xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: crops is rank 3, not the required rank 2.
func @testBatchToSpaceInvalidCropsRank(%arg0: tensor<*xf32>, %arg1: tensor<?x?x?xi32>) {
  // expected-error @+1 {{'tf.BatchToSpace' op requires crops to be a 2D tensor, but got 'tensor<?x?x?xi32>'}}
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<?x?x?xi32>) -> tensor<*xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: crops dimension 0 is statically 3, not 2.
func @testBatchToSpaceInvalidCropsFirstDim(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi32>) {
  // expected-error @+1 {{'tf.BatchToSpace' op requires crops to be a tensor<2x2>, but got 'tensor<3x?xi32>'}}
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<3x?xi32>) -> tensor<*xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: crops dimension 1 is statically 3, not 2.
func @testBatchToSpaceInvalidCropsSecondDim(%arg0: tensor<*xf32>, %arg1: tensor<?x3xi32>) {
  // expected-error @+1 {{'tf.BatchToSpace' op requires crops to be a tensor<2x2>, but got 'tensor<?x3xi32>'}}
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<?x3xi32>) -> tensor<*xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: constant crops contain negative values.
func @testBatchToSpaceBadCropValues(%arg0: tensor<*xf32>) {
  %crops = "tf.Const"() {value = dense<[[-1, -2], [-3, -4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
  // expected-error @+1 {{'tf.BatchToSpace' op requires all crop values to be nonnegative, but got dense<[[-1, -2], [-3, -4]]> : tensor<2x2xi32>}}
  %0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 2 : i64} : (tensor<*xf32>, tensor<2x2xi32>) -> tensor<*xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: output is rank 1, not the required rank 4.
func @testBatchToSpaceInvalidOutputRank(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) {
  // expected-error @+1 {{'tf.BatchToSpace' op requires output to be a 4D tensor, but got 'tensor<8xf32>'}}
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<8xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: output batch should be 16/(2*2) = 4, but is 8.
func @testBatchToSpaceInvalidOutputBatch(%arg0: tensor<16x8x8x3xf32>, %arg1: tensor<*xi32>) {
  // expected-error @+1 {{'tf.BatchToSpace' op requires output batch (dimension 0) to be equal to input batch (dimension 0) / (block_size * block_size), but got output batch 8, input batch 16, and block_size 2}}
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<8x8x8x3xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: with unknown crops, output height may be at most 8*2 = 16, but is 17.
func @testBatchToSpaceInvalidOutputHeight(%arg0: tensor<16x8x8x3xf32>, %arg1: tensor<*xi32>) {
  // expected-error @+1 {{'tf.BatchToSpace' op requires output height (dimension 1) to be less than or equal to input height (dimension 1) * block_size, but got output height 17, input height 8, and block_size 2}}
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<4x17x8x3xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: with known crops, output height must be 8*2 - 1 - 2 = 13, but is 8.
func @testBatchToSpaceInvalidOutputHeightCrops(%arg0: tensor<16x8x8x3xf32>) {
  %crops = "tf.Const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
  // expected-error @+1 {{'tf.BatchToSpace' op requires output height (dimension 1) to be equal to input height (dimension 1) * block_size - crop_top - crop_bottom, but got output height 8, input height 8, crop_top 1, crop_bottom 2, and block_size 2}}
  %0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<2x2xi32>) -> tensor<4x8x9x3xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: with unknown crops, output width may be at most 4*2 = 8, but is 9.
func @testBatchToSpaceInvalidOutputWidth(%arg0: tensor<16x4x4x3xf32>, %arg1: tensor<*xi32>) {
  // expected-error @+1 {{'tf.BatchToSpace' op requires output width (dimension 2) to be less than or equal to input width (dimension 2) * block_size, but got output width 9, input width 4, and block_size 2}}
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x4x4x3xf32>, tensor<*xi32>) -> tensor<4x4x9x3xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: with known crops, output width must be 8*2 - 3 - 4 = 9, but is 8
// (height 13 = 8*2 - 1 - 2 is correct, so only the width check fires).
func @testBatchToSpaceInvalidOutputWidthCrops(%arg0: tensor<16x8x8x3xf32>) {
  %crops = "tf.Const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
  // expected-error @+1 {{'tf.BatchToSpace' op requires output width (dimension 2) to be equal to input width (dimension 2) * block_size - crop_left - crop_right, but got output width 8, input width 8, crop_left 3, crop_right 4, and block_size 2}}
  %0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<2x2xi32>) -> tensor<4x13x8x3xf32>
  return
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: output depth (8) must equal input depth (3).
func @testBatchToSpaceInvalidOutputDepth(%arg0: tensor<16x8x8x3xf32>, %arg1: tensor<*xi32>) {
  // expected-error @+1 {{'tf.BatchToSpace' op requires output depth (dimension 3) to be equal to input depth (dimension 3), but got output depth 8 and input depth 3}}
  %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<4x8x8x8xf32>
  return
}
|
||||
|
Loading…
Reference in New Issue
Block a user