diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
index 4193edf8cc6..7d2f630869a 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir
@@ -499,4 +499,16 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
     %outputs_2 = "tf.TensorSliceDataset"(%outputs_0) {device = "", output_shapes = [#tf.shape<>]} : (tensor<*xf32>) -> tensor<!tf.variant>
     return
   }
+
+  // Test resource result subtypes are propagated to call op results.
+  // CHECK-LABEL: func @pcall_resource_result
+  func @pcall_resource_result(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
+    // CHECK: "tf.StatefulPartitionedCall"
+    // CHECK-SAME: (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
+    %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_resource_result_func} : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource>
+    return
+  }
+  func @pcall_resource_result_func(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>> {
+    return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
+  }
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
index f9c81634ae5..d2e497a1dec 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
@@ -210,6 +210,21 @@ bool CanBeRefined(Type type) {
           shape_type.getElementType().isa<TF::TensorFlowTypeWithSubtype>());
 }
 
+// Returns whether `original_type` type can be refined with
+// `potential_refined_type` type.
+bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
+  if (!CanBeRefined(original_type)) return false;
+
+  auto shape_type = potential_refined_type.dyn_cast<ShapedType>();
+  if (!shape_type) return false;
+  if (shape_type.hasRank()) return true;
+
+  auto element_type_with_subtype =
+      shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
+  return element_type_with_subtype &&
+         !element_type_with_subtype.GetSubtypes().empty();
+}
+
 // Infers the shape from a (Stateful)PartionedCall operation by looking up the
 // called function and propagating the return type.
 bool InferShapeForCall(Operation* op) {
@@ -224,20 +239,18 @@ bool InferShapeForCall(Operation* op) {
   // Map each of the results of the call to the returned type of the
   // function.
   for (auto result : zip(op->getResults(), func.getType().getResults())) {
-    if (std::get<0>(result).getType() == std::get<1>(result)) continue;
-    // Skip already statically shaped results.
-    if (!CanBeRefined(std::get<0>(result).getType())) continue;
-
-    auto shaped_type = std::get<0>(result).getType().cast<ShapedType>();
-    auto new_type = std::get<1>(result).dyn_cast<RankedTensorType>();
-    if (!new_type) continue;
+    auto call_op_result = std::get<0>(result);
+    auto func_result_type = std::get<1>(result);
+    if (call_op_result.getType() == func_result_type) continue;
+    if (!CanRefineTypeWith(call_op_result.getType(), func_result_type))
+      continue;
 
     // Inserts a cast back to the original type if any user is not in the
     // TF dialect.
-    AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result),
-                                       op->getDialect(), shaped_type);
+    AddCastBackForUnsupportedNonTFUses(op, call_op_result, op->getDialect(),
+                                       call_op_result.getType());
     // Finally we inferred the shape and replace the type for this result.
-    std::get<0>(result).setType(new_type);
+    call_op_result.setType(func_result_type);
     changed = true;
   }
   return changed;
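
A minimal before/after sketch of what this buys (not part of the patch; the callee name
@callee and its tensor<f32> resource subtype are illustrative, mirroring the test above).
Previously InferShapeForCall only refined call results with ranked callee result types,
so an unranked resource result with a known subtype was skipped; CanRefineTypeWith also
accepts an unranked type whose element type carries at least one subtype:

    // Before shape inference: the call result carries no resource subtype.
    // (@callee is a placeholder for the called function.)
    %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee}
        : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource>

    // After shape inference: the callee's result subtype is copied onto the call result.
    %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee}
        : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>

Note that CanRefineTypeWith still rejects a candidate type that is unranked and has no
subtypes, so a result is never rewritten to a type carrying no new information.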