diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index 78623ca3c61..69b8f15320f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -2603,9 +2603,12 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, << variadic_idx_str << " to match rank of operand" << variadic_idx_str; } else if (result_ranked_type.hasStaticShape()) { - // The operand is an unranked tensor, verify that the result is dynamic. - return op->emitOpError("requires dynamic shape result") - << variadic_idx_str << " for unranked operand" << variadic_idx_str; + // The operand is an unranked tensor, print a warning if the result + // is static. + // Note: We do not handle this situation as an error, this would be too + // restrictive due to incompleteness of shape inference at this point. + op->emitWarning("has static shape result") + << variadic_idx_str << " for unranked operand" << variadic_idx_str; } Type element_type = result_ranked_type.getElementType(); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index ffa287e0e53..3560fec7b7d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1326,7 +1326,7 @@ func @testShapeMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> { func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires dynamic shape result for unranked operand}} + // expected-warning @+1 {{has static shape result for unranked operand}} %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32> return %0 : tensor<2xi32> } @@ -1370,7 +1370,7 @@ func @testShapeNMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> { func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires dynamic shape result #1 for unranked operand #1}} + // expected-warning @+1 {{has static shape result #1 for unranked operand #1}} %0:2 = "tf.ShapeN"(%arg0, %arg0) : (tensor<*xf32>, tensor<*xf32>) -> (tensor, tensor<2xi32>) return %0#1 : tensor<2xi32> } @@ -1428,7 +1428,7 @@ func @testVariableShapeMismatchDim(%arg0: tensor<*x!tf.resource>>) -> tensor<2xi32> { - // expected-error @+1 {{requires dynamic shape result for unranked operand}} + // expected-warning @+1 {{has static shape result for unranked operand}} %0 = "tf.VariableShape"(%arg0) {output = "tfdtype$DT_INT32"} : (tensor<*x!tf.resource>>) -> tensor<2xi32> return %0 : tensor<2xi32> }