Expand inference pass to refine types even where the shape can't be refined
Previously, the shape inference pass would not attempt to refine a type unless there was a refinement of the shape. This is insufficient for shaped types whose element type itself carries a shape, such as resources with subtypes. Propagating resource subtypes is necessary for certain op decompositions, like tf.VariableShape, and it also provides more compile-time information about resources passed across functions.

PiperOrigin-RevId: 321370649
Change-Id: Iaacce31605f04b556a3c1b3cc050fa52670c9d66
parent dada5c989e, commit de473e40fc
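To illustrate the tf.VariableShape motivation mentioned above, here is a minimal sketch (my own example, not part of this change; the function name is hypothetical): once the resource carries a statically shaped subtype, a later decomposition pass has enough information to replace tf.VariableShape with a constant.

// Hypothetical illustration: the subtype tensor<2x4xf32> is what makes the
// VariableShape result computable at compile time (e.g. foldable to
// "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>}).
func @variable_shape_example(%arg0: tensor<*x!tf.resource<tensor<2x4xf32>>>) -> tensor<2xi32> {
  %0 = "tf.VariableShape"(%arg0) : (tensor<*x!tf.resource<tensor<2x4xf32>>>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}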
@@ -499,4 +499,16 @@ func @multiple_blocks_one_return(%arg0: tensor<?xf32>) -> tensor<*xf32> {
     %outputs_2 = "tf.TensorSliceDataset"(%outputs_0) {device = "", output_shapes = [#tf.shape<>]} : (tensor<*xf32>) -> tensor<!tf.variant>
     return
   }
+
+  // Test resource result subtypes are propagated to call op results.
+  // CHECK-LABEL: func @pcall_resource_result
+  func @pcall_resource_result(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
+    // CHECK: "tf.StatefulPartitionedCall"
+    // CHECK-SAME: (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
+    %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_resource_result_func} : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource>
+    return
+  }
+  func @pcall_resource_result_func(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>> {
+    return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
+  }
 }
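For readability, this is the IR the CHECK lines above expect once shape inference has run. Expanding the CHECK directives into a full function is my own reading of the test, but the types come directly from it: the call result picks up the callee's resource subtype even though its shape stays unranked.

func @pcall_resource_result(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
  %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_resource_result_func} : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<*x!tf.resource<tensor<f32>>>
  return
}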
@@ -210,6 +210,21 @@ bool CanBeRefined(Type type) {
           shape_type.getElementType().isa<TF::ResourceType, TF::VariantType>());
 }
 
+// Returns whether `original_type` type can be refined with
+// `potential_refined_type` type.
+bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
+  if (!CanBeRefined(original_type)) return false;
+
+  auto shape_type = potential_refined_type.dyn_cast<ShapedType>();
+  if (!shape_type) return false;
+  if (shape_type.hasRank()) return true;
+
+  auto element_type_with_subtype =
+      shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
+  return element_type_with_subtype &&
+         !element_type_with_subtype.GetSubtypes().empty();
+}
+
 // Infers the shape from a (Stateful)PartionedCall operation by looking up the
 // called function and propagating the return type.
 bool InferShapeForCall(Operation* op) {
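To make the new helper concrete, a small illustration (the declarations and names below are mine, not from the change): against a call result originally typed tensor<*x!tf.resource>, a callee return type counts as a refinement if it has a rank or, failing that, if its element type carries a non-empty subtype list.

// All three declarations parse; only the first two return types would be
// treated as refinements of an original tensor<*x!tf.resource> result.
func @ranked_resource() -> tensor<4x!tf.resource>                     // ranked
func @unranked_with_subtype() -> tensor<*x!tf.resource<tensor<f32>>>  // unranked, but carries a subtype
func @unranked_no_subtype() -> tensor<*x!tf.resource>                 // neither; skipped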
@@ -224,20 +239,18 @@ bool InferShapeForCall(Operation* op) {
   // Map each of the results of the call to the returned type of the
   // function.
   for (auto result : zip(op->getResults(), func.getType().getResults())) {
-    if (std::get<0>(result).getType() == std::get<1>(result)) continue;
-    // Skip already statically shaped results.
-    if (!CanBeRefined(std::get<0>(result).getType())) continue;
-
-    auto shaped_type = std::get<0>(result).getType().cast<ShapedType>();
-    auto new_type = std::get<1>(result).dyn_cast<RankedTensorType>();
-    if (!new_type) continue;
+    auto call_op_result = std::get<0>(result);
+    auto func_result_type = std::get<1>(result);
+    if (call_op_result.getType() == func_result_type) continue;
+    if (!CanRefineTypeWith(call_op_result.getType(), func_result_type))
+      continue;
 
     // Inserts a cast back to the original type if any user is not in the
     // TF dialect.
-    AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result),
-                                       op->getDialect(), shaped_type);
+    AddCastBackForUnsupportedNonTFUses(op, call_op_result, op->getDialect(),
+                                       call_op_result.getType());
     // Finally we inferred the shape and replace the type for this result.
-    std::get<0>(result).setType(new_type);
+    call_op_result.setType(func_result_type);
     changed = true;
   }
   return changed;