From 23be4f5d44f28734620fa508d9807d9aca2ce074 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Wed, 17 Jun 2020 08:32:24 -0700 Subject: [PATCH] - Fix PropagatePotentiallyWrittenWithinUnhandledOp() to mark resource uses within regions as potentially written when a resource operand is seen. This also fixes the case when multiple resources are used as operands in the same unhandled op - Add test case to demonstrate the issue PiperOrigin-RevId: 316893534 Change-Id: I9e688a90155efd990eb5ef835c23933825bcbdd0 --- ...f_saved_model_optimize_global_tensors.mlir | 88 ++++++++++++++++++- .../transforms/optimize_global_tensors.cc | 22 ++--- 2 files changed, 95 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir index 9d8911d306d..0c68cf0cf64 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir @@ -1,7 +1,7 @@ // RUN: tf-opt -tf-saved-model-optimize-global-tensors -split-input-file %s | FileCheck %s //===----------------------------------------------------------------------===// -// Freezing. +// Immutability. //===----------------------------------------------------------------------===// module attributes {tf_saved_model.semantics} { @@ -142,3 +142,89 @@ module attributes {tf_saved_model.semantics} { // Test running the pass on a module that does not have // tf_saved_model.semantics. module {} + +// ----- + +// Test use as an input in unhandled op +module attributes {tf_saved_model.semantics} { + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["f"]} { + "tf.unhandled_op"(%arg0) : (tensor>>) -> () + return + } +} + + +// ----- + +// Test use as a region capture in an unhandled op +module attributes {tf_saved_model.semantics} { + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["f"]} { + "tf.unhandled"() ({ + %val = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + "tf.unhandled_terminator"() : () -> () + }) : () -> () + return + } +} + +// ----- + +// Test use as region capture as well as input in an unhandled op +// to the unhandled op. +module attributes {tf_saved_model.semantics} { + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor, value = dense<22.> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}, %arg1: tensor>> {tf_saved_model.bound_input = @u}) + attributes {tf_saved_model.exported_names = ["f"]} { + %0 = "tf.unhandled"(%arg0) ({ + %val = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor + "tf.unhandled_terminator"() : () -> () + }) : (tensor>>) -> (tensor>>) + return + } +} + +// ----- + +// Test multiple global tensors uses as operands for an unhandled op. +module attributes {tf_saved_model.semantics} { + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor, value = dense<22.> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}, %arg1: tensor>> {tf_saved_model.bound_input = @u}) + attributes {tf_saved_model.exported_names = ["f"]} { + "tf.unhandled"(%arg0, %arg1) : (tensor>>, tensor>>) -> () + return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index cd8f988fd5f..07cc6203cbd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -56,14 +56,14 @@ struct GlobalTensorUse { using GlobalTensorUsesMap = std::map>; -static bool IsResourceType(Type type) { +bool IsResourceType(Type type) { if (auto tensor_type = type.dyn_cast()) { return tensor_type.getElementType().isa(); } return false; } -static bool IsResource(Value value) { return IsResourceType(value.getType()); } +bool IsResource(Value value) { return IsResourceType(value.getType()); } class ResourceAnalyzer { public: @@ -129,30 +129,24 @@ class ResourceAnalyzer { // this errs on the side of being conservative. We should improve // this by using either a property or a trait that clearly // identifies ops with resource mutating behavior. - if (PropagatePotentiallyWrittenWithinUnhandledOp(op)) { - return; - } + PropagatePotentiallyWrittenWithinUnhandledOp(op); }); return success(); } // If an op is not one of the handled ones, we assume all resource usages // within its purview are mutating in nature. - bool PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) { + void PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) { for (auto operand : op->getOperands()) { if (IsResource(operand)) { SetPotentiallyWritten(operand); - return true; } } - bool uses_resources = false; visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) { if (IsResource(operand->get())) { SetPotentiallyWritten(operand->get()); - uses_resources = true; } }); - return uses_resources; } // Given a funcOp associated with the callee and operands from the @@ -212,7 +206,7 @@ bool IsImmutable(GlobalTensorOp global_tensor, return true; } -static GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) { +GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) { GlobalTensorUsesMap global_tensor_uses; SymbolTable symbol_table(module); @@ -293,13 +287,13 @@ void OptimizeGlobalTensorsPass::runOnOperation() { EraseUnusedGlobalTensors(module, global_tensor_uses); } -} // namespace - // For "opt" to pick up this pass. -static PassRegistration pass( +PassRegistration pass( "tf-saved-model-optimize-global-tensors", "Optimize tf_saved_model.global_tensor's."); +} // namespace + std::unique_ptr> CreateOptimizeGlobalTensorsPass() { return std::make_unique(); }