Allow static result shape for unranked operand in shape verifier

Previously, a static result shape for an unranked operand produced an error in
shape verifier. This was too restrictive because shape inference is often
incomplete at this point.

PiperOrigin-RevId: 312167322
Change-Id: Ia198f07699174a4ea3c77099c9408def95e058be
This commit is contained in:
Michael Gester 2020-05-18 15:35:17 -07:00 committed by TensorFlower Gardener
parent f5c5747f13
commit 94108993a3
2 changed files with 9 additions and 6 deletions

View File

@ -2603,8 +2603,11 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
<< variadic_idx_str << " to match rank of operand"
<< variadic_idx_str;
} else if (result_ranked_type.hasStaticShape()) {
// The operand is an unranked tensor, verify that the result is dynamic.
return op->emitOpError("requires dynamic shape result")
// The operand is an unranked tensor, print a warning if the result
// is static.
// Note: We do not handle this situation as an error, this would be too
// restrictive due to incompleteness of shape inference at this point.
op->emitWarning("has static shape result")
<< variadic_idx_str << " for unranked operand" << variadic_idx_str;
}

View File

@ -1326,7 +1326,7 @@ func @testShapeMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> {
func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> {
^bb0(%arg0: tensor<*xf32>):
// expected-error @+1 {{requires dynamic shape result for unranked operand}}
// expected-warning @+1 {{has static shape result for unranked operand}}
%0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}
@ -1370,7 +1370,7 @@ func @testShapeNMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> {
func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> {
^bb0(%arg0: tensor<*xf32>):
// expected-error @+1 {{requires dynamic shape result #1 for unranked operand #1}}
// expected-warning @+1 {{has static shape result #1 for unranked operand #1}}
%0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<?xi32>, tensor<2xi32>)
return %0#1 : tensor<2xi32>
}
@ -1428,7 +1428,7 @@ func @testVariableShapeMismatchDim(%arg0: tensor<*x!tf.resource<tensor<1x32x32x1
// -----
func @testVariableShapeWrongResultDimDynamic(%arg0: tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32> {
// expected-error @+1 {{requires dynamic shape result for unranked operand}}
// expected-warning @+1 {{has static shape result for unranked operand}}
%0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}