Allow static result shape for unranked operand in shape verifier
Previously, a static result shape for an unranked operand produced an error in the shape verifier. This was too restrictive because shape inference is often incomplete at this point; the verifier now emits a warning instead.

PiperOrigin-RevId: 312167322
Change-Id: Ia198f07699174a4ea3c77099c9408def95e058be
parent f5c5747f13
commit 94108993a3
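Concretely, IR like the following previously failed verification with "requires dynamic shape result for unranked operand" and is now accepted with a warning. This is a minimal sketch mirroring the updated test cases below; the function name is illustrative:

func @shape_of_unranked(%arg0: tensor<*xf32>) -> tensor<2xi32> {
  // The operand is unranked, yet the result shape is static; the verifier
  // now warns ("has static shape result for unranked operand") instead of
  // rejecting the op.
  %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}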
@@ -2603,8 +2603,11 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type,
            << variadic_idx_str << " to match rank of operand"
            << variadic_idx_str;
   } else if (result_ranked_type.hasStaticShape()) {
-    // The operand is an unranked tensor, verify that the result is dynamic.
-    return op->emitOpError("requires dynamic shape result")
+    // The operand is an unranked tensor, print a warning if the result
+    // is static.
+    // Note: We do not handle this situation as an error, this would be too
+    // restrictive due to incompleteness of shape inference at this point.
+    op->emitWarning("has static shape result")
         << variadic_idx_str << " for unranked operand" << variadic_idx_str;
   }
 
@@ -1326,7 +1326,7 @@ func @testShapeMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> {
 
 func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> {
 ^bb0(%arg0: tensor<*xf32>):
-  // expected-error @+1 {{requires dynamic shape result for unranked operand}}
+  // expected-warning @+1 {{has static shape result for unranked operand}}
   %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32>
   return %0 : tensor<2xi32>
 }
@@ -1370,7 +1370,7 @@ func @testShapeNMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> {
 
 func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> {
 ^bb0(%arg0: tensor<*xf32>):
-  // expected-error @+1 {{requires dynamic shape result #1 for unranked operand #1}}
+  // expected-warning @+1 {{has static shape result #1 for unranked operand #1}}
   %0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<?xi32>, tensor<2xi32>)
   return %0#1 : tensor<2xi32>
 }
@@ -1428,7 +1428,7 @@ func @testVariableShapeMismatchDim(%arg0: tensor<*x!tf.resource<tensor<1x32x32x1
 // -----
 
 func @testVariableShapeWrongResultDimDynamic(%arg0: tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32> {
-  // expected-error @+1 {{requires dynamic shape result for unranked operand}}
+  // expected-warning @+1 {{has static shape result for unranked operand}}
   %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource<tensor<*xf32>>>) -> tensor<2xi32>
   return %0 : tensor<2xi32>
 }
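These expected-warning annotations are checked by MLIR's diagnostic verifier. Assuming the test file keeps its usual lit harness, the tests run under a RUN line like the following (shown as an MLIR comment; the exact flags in the actual file may differ):

// RUN: tf-opt %s -split-input-file -verify-diagnostics

With -verify-diagnostics, every diagnostic the verifier emits must match an expected-error/expected-warning annotation on the corresponding line, which is why downgrading the error to a warning also requires updating the annotations as done above. The `// -----` separators delimit the per-test inputs consumed by -split-input-file.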