- Fix PropagatePotentiallyWrittenWithinUnhandledOp() to mark resource uses within
regions as potentially written when a resource operand is seen. This also fixes the case when multiple resources are used as operands in the same unhandled op - Add test case to demonstrate the issue PiperOrigin-RevId: 316893534 Change-Id: I9e688a90155efd990eb5ef835c23933825bcbdd0
This commit is contained in:
parent
3d4ca5a00a
commit
23be4f5d44
@ -1,7 +1,7 @@
|
||||
// RUN: tf-opt -tf-saved-model-optimize-global-tensors -split-input-file %s | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Freezing.
|
||||
// Immutability.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
@ -142,3 +142,89 @@ module attributes {tf_saved_model.semantics} {
|
||||
// Test running the pass on a module that does not have
|
||||
// tf_saved_model.semantics.
|
||||
module {}
|
||||
|
||||
// -----
|
||||
|
||||
// Test use as an input in unhandled op
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
"tf.unhandled_op"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Test use as a region capture in an unhandled op
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
"tf.unhandled"() ({
|
||||
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
"tf.unhandled_terminator"() : () -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test use as region capture as well as input in an unhandled op
|
||||
// to the unhandled op.
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor<f32>, value = dense<22.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @u})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%0 = "tf.unhandled"(%arg0) ({
|
||||
%val = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
"tf.unhandled_terminator"() : () -> ()
|
||||
}) : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<!tf.resource<tensor<f32>>>)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test multiple global tensors uses as operands for an unhandled op.
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
|
||||
|
||||
// CHECK: "tf_saved_model.global_tensor"() {
|
||||
// CHECK-SAME: is_mutable
|
||||
// CHECK-SAME: } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor<f32>, value = dense<22.> : tensor<f32> } : () -> ()
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @u})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
"tf.unhandled"(%arg0, %arg1) : (tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>) -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -56,14 +56,14 @@ struct GlobalTensorUse {
|
||||
using GlobalTensorUsesMap =
|
||||
std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;
|
||||
|
||||
static bool IsResourceType(Type type) {
|
||||
bool IsResourceType(Type type) {
|
||||
if (auto tensor_type = type.dyn_cast<TensorType>()) {
|
||||
return tensor_type.getElementType().isa<TF::ResourceType>();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool IsResource(Value value) { return IsResourceType(value.getType()); }
|
||||
bool IsResource(Value value) { return IsResourceType(value.getType()); }
|
||||
|
||||
class ResourceAnalyzer {
|
||||
public:
|
||||
@ -129,30 +129,24 @@ class ResourceAnalyzer {
|
||||
// this errs on the side of being conservative. We should improve
|
||||
// this by using either a property or a trait that clearly
|
||||
// identifies ops with resource mutating behavior.
|
||||
if (PropagatePotentiallyWrittenWithinUnhandledOp(op)) {
|
||||
return;
|
||||
}
|
||||
PropagatePotentiallyWrittenWithinUnhandledOp(op);
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
||||
// If an op is not one of the handled ones, we assume all resource usages
|
||||
// within its purview are mutating in nature.
|
||||
bool PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) {
|
||||
void PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) {
|
||||
for (auto operand : op->getOperands()) {
|
||||
if (IsResource(operand)) {
|
||||
SetPotentiallyWritten(operand);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
bool uses_resources = false;
|
||||
visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
|
||||
if (IsResource(operand->get())) {
|
||||
SetPotentiallyWritten(operand->get());
|
||||
uses_resources = true;
|
||||
}
|
||||
});
|
||||
return uses_resources;
|
||||
}
|
||||
|
||||
// Given a funcOp associated with the callee and operands from the
|
||||
@ -212,7 +206,7 @@ bool IsImmutable(GlobalTensorOp global_tensor,
|
||||
return true;
|
||||
}
|
||||
|
||||
static GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
|
||||
GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
|
||||
GlobalTensorUsesMap global_tensor_uses;
|
||||
|
||||
SymbolTable symbol_table(module);
|
||||
@ -293,13 +287,13 @@ void OptimizeGlobalTensorsPass::runOnOperation() {
|
||||
EraseUnusedGlobalTensors(module, global_tensor_uses);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// For "opt" to pick up this pass.
|
||||
static PassRegistration<OptimizeGlobalTensorsPass> pass(
|
||||
PassRegistration<OptimizeGlobalTensorsPass> pass(
|
||||
"tf-saved-model-optimize-global-tensors",
|
||||
"Optimize tf_saved_model.global_tensor's.");
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
|
||||
return std::make_unique<OptimizeGlobalTensorsPass>();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user