Expand inference pass to refine types even where the shape can't

Previously the shape inference pass would not attempt to refine the type unless their is a refinement of the shape. This is insufficient for shaped types where the element type has a shape. Having resource subtypes propagated is necessary for certain op decompositions, like tf.VariableShape. This will also allow more compile time information about resources for resources passed across functions.

PiperOrigin-RevId: 321370649
Change-Id: Iaacce31605f04b556a3c1b3cc050fa52670c9d66
This commit is contained in:
Andy Ly 2020-07-15 09:02:53 -07:00 committed by TensorFlower Gardener
parent dada5c989e
commit de473e40fc
2 changed files with 35 additions and 10 deletions

View File

@ -499,4 +499,16 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
%outputs_2 = "tf.TensorSliceDataset"(%outputs_0) {device = "", output_shapes = [#tf.shape<>]} : (tensor<*xf32>) -> tensor<!tf.variant>
return
}
// Test resource result subtypes are propagated to call op results.
// CHECK-LABEL: func @pcall_resource_result
func @pcall_resource_result(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
// CHECK: "tf.StatefulPartitionedCall"
// CHECK-SAME: (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
%0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_resource_result_func} : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource>
return
}
func @pcall_resource_result_func(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>> {
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}
}

View File

@ -210,6 +210,21 @@ bool CanBeRefined(Type type) {
shape_type.getElementType().isa<TF::ResourceType, TF::VariantType>());
}
// Returns whether `original_type` type can be refined with
// `potential_refined_type` type.
bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
if (!CanBeRefined(original_type)) return false;
auto shape_type = potential_refined_type.dyn_cast<ShapedType>();
if (!shape_type) return false;
if (shape_type.hasRank()) return true;
auto element_type_with_subtype =
shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
return element_type_with_subtype &&
!element_type_with_subtype.GetSubtypes().empty();
}
// Infers the shape from a (Stateful)PartionedCall operation by looking up the
// called function and propagating the return type.
bool InferShapeForCall(Operation* op) {
@ -224,20 +239,18 @@ bool InferShapeForCall(Operation* op) {
// Map each of the results of the call to the returned type of the
// function.
for (auto result : zip(op->getResults(), func.getType().getResults())) {
if (std::get<0>(result).getType() == std::get<1>(result)) continue;
// Skip already statically shaped results.
if (!CanBeRefined(std::get<0>(result).getType())) continue;
auto shaped_type = std::get<0>(result).getType().cast<ShapedType>();
auto new_type = std::get<1>(result).dyn_cast<RankedTensorType>();
if (!new_type) continue;
auto call_op_result = std::get<0>(result);
auto func_result_type = std::get<1>(result);
if (call_op_result.getType() == func_result_type) continue;
if (!CanRefineTypeWith(call_op_result.getType(), func_result_type))
continue;
// Inserts a cast back to the original type if any user is not in the
// TF dialect.
AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result),
op->getDialect(), shaped_type);
AddCastBackForUnsupportedNonTFUses(op, call_op_result, op->getDialect(),
call_op_result.getType());
// Finally we inferred the shape and replace the type for this result.
std::get<0>(result).setType(new_type);
call_op_result.setType(func_result_type);
changed = true;
}
return changed;