Allow static result shape for unranked operand in shape verifier

Previously, a static result shape for an unranked operand produced an error in
shape verifier. This was too restrictive because shape inference is often
incomplete at this point.

PiperOrigin-RevId: 312167322
Change-Id: Ia198f07699174a4ea3c77099c9408def95e058be
This commit is contained in:
Michael Gester 2020-05-18 15:35:17 -07:00 committed by TensorFlower Gardener
parent f5c5747f13
commit 94108993a3
2 changed files with 9 additions and 6 deletions

View File

@ -2603,8 +2603,11 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
<< variadic_idx_str << " to match rank of operand" << variadic_idx_str << " to match rank of operand"
<< variadic_idx_str; << variadic_idx_str;
} else if (result_ranked_type.hasStaticShape()) { } else if (result_ranked_type.hasStaticShape()) {
// The operand is an unranked tensor, verify that the result is dynamic. // The operand is an unranked tensor, print a warning if the result
return op->emitOpError("requires dynamic shape result") // is static.
// Note: We do not handle this situation as an error, this would be too
// restrictive due to incompleteness of shape inference at this point.
op->emitWarning("has static shape result")
<< variadic_idx_str << " for unranked operand" << variadic_idx_str; << variadic_idx_str << " for unranked operand" << variadic_idx_str;
} }

View File

@ -1326,7 +1326,7 @@ func @testShapeMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> {
func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> {
^bb0(%arg0: tensor<*xf32>): ^bb0(%arg0: tensor<*xf32>):
// expected-error @+1 {{requires dynamic shape result for unranked operand}} // expected-warning @+1 {{has static shape result for unranked operand}}
%0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32> %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
@ -1370,7 +1370,7 @@ func @testShapeNMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> {
func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> {
^bb0(%arg0: tensor<*xf32>): ^bb0(%arg0: tensor<*xf32>):
// expected-error @+1 {{requires dynamic shape result #1 for unranked operand #1}} // expected-warning @+1 {{has static shape result #1 for unranked operand #1}}
%0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<?xi32>, tensor<2xi32>) %0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<?xi32>, tensor<2xi32>)
return %0#1 : tensor<2xi32> return %0#1 : tensor<2xi32>
} }
@ -1428,7 +1428,7 @@ func @testVariableShapeMismatchDim(%arg0: tensor<*x!tf.resource<tensor<1x32x32x1
// ----- // -----
func @testVariableShapeWrongResultDimDynamic(%arg0: tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32> { func @testVariableShapeWrongResultDimDynamic(%arg0: tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32> {
// expected-error @+1 {{requires dynamic shape result for unranked operand}} // expected-warning @+1 {{has static shape result for unranked operand}}
%0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32> %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }