- Fix PropagatePotentiallyWrittenWithinUnhandledOp() to mark resource uses within

regions as potentially written when a resource operand is seen. This also fixes
  the case when multiple resources are used as operands in the same unhandled op
- Add test case to demonstrate the issue

PiperOrigin-RevId: 316893534
Change-Id: I9e688a90155efd990eb5ef835c23933825bcbdd0
This commit is contained in:
Rahul Joshi 2020-06-17 08:32:24 -07:00 committed by TensorFlower Gardener
parent 3d4ca5a00a
commit 23be4f5d44
2 changed files with 95 additions and 15 deletions

View File

@ -1,7 +1,7 @@
// RUN: tf-opt -tf-saved-model-optimize-global-tensors -split-input-file %s | FileCheck %s
//===----------------------------------------------------------------------===//
// Freezing.
// Immutability.
//===----------------------------------------------------------------------===//
module attributes {tf_saved_model.semantics} {
@ -142,3 +142,89 @@ module attributes {tf_saved_model.semantics} {
// Test running the pass on a module that does not have
// tf_saved_model.semantics.
module {}
// -----
// Test use as an input in unhandled op
module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
attributes {tf_saved_model.exported_names = ["f"]} {
"tf.unhandled_op"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> ()
return
}
}
// -----
// Test use as a region capture in an unhandled op
module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v})
attributes {tf_saved_model.exported_names = ["f"]} {
"tf.unhandled"() ({
%val = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
"tf.unhandled_terminator"() : () -> ()
}) : () -> ()
return
}
}
// -----
// Test use as region capture as well as input in an unhandled op
// to the unhandled op.
module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor<f32>, value = dense<22.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @u})
attributes {tf_saved_model.exported_names = ["f"]} {
%0 = "tf.unhandled"(%arg0) ({
%val = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
"tf.unhandled_terminator"() : () -> ()
}) : (tensor<!tf.resource<tensor<f32>>>) -> (tensor<!tf.resource<tensor<f32>>>)
return
}
}
// -----
// Test multiple global tensors uses as operands for an unhandled op.
module attributes {tf_saved_model.semantics} {
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.> : tensor<f32> } : () -> ()
// CHECK: "tf_saved_model.global_tensor"() {
// CHECK-SAME: is_mutable
// CHECK-SAME: } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor<f32>, value = dense<22.> : tensor<f32> } : () -> ()
func @f(%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @u})
attributes {tf_saved_model.exported_names = ["f"]} {
"tf.unhandled"(%arg0, %arg1) : (tensor<!tf.resource<tensor<f32>>>, tensor<!tf.resource<tensor<f32>>>) -> ()
return
}
}

View File

@ -56,14 +56,14 @@ struct GlobalTensorUse {
using GlobalTensorUsesMap =
std::map<GlobalTensorOp, std::vector<GlobalTensorUse>>;
static bool IsResourceType(Type type) {
bool IsResourceType(Type type) {
if (auto tensor_type = type.dyn_cast<TensorType>()) {
return tensor_type.getElementType().isa<TF::ResourceType>();
}
return false;
}
static bool IsResource(Value value) { return IsResourceType(value.getType()); }
bool IsResource(Value value) { return IsResourceType(value.getType()); }
class ResourceAnalyzer {
public:
@ -129,30 +129,24 @@ class ResourceAnalyzer {
// this errs on the side of being conservative. We should improve
// this by using either a property or a trait that clearly
// identifies ops with resource mutating behavior.
if (PropagatePotentiallyWrittenWithinUnhandledOp(op)) {
return;
}
PropagatePotentiallyWrittenWithinUnhandledOp(op);
});
return success();
}
// If an op is not one of the handled ones, we assume all resource usages
// within its purview are mutating in nature.
bool PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) {
void PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) {
for (auto operand : op->getOperands()) {
if (IsResource(operand)) {
SetPotentiallyWritten(operand);
return true;
}
}
bool uses_resources = false;
visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) {
if (IsResource(operand->get())) {
SetPotentiallyWritten(operand->get());
uses_resources = true;
}
});
return uses_resources;
}
// Given a funcOp associated with the callee and operands from the
@ -212,7 +206,7 @@ bool IsImmutable(GlobalTensorOp global_tensor,
return true;
}
static GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) {
GlobalTensorUsesMap global_tensor_uses;
SymbolTable symbol_table(module);
@ -293,13 +287,13 @@ void OptimizeGlobalTensorsPass::runOnOperation() {
EraseUnusedGlobalTensors(module, global_tensor_uses);
}
} // namespace
// For "opt" to pick up this pass.
static PassRegistration<OptimizeGlobalTensorsPass> pass(
PassRegistration<OptimizeGlobalTensorsPass> pass(
"tf-saved-model-optimize-global-tensors",
"Optimize tf_saved_model.global_tensor's.");
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass() {
return std::make_unique<OptimizeGlobalTensorsPass>();
}