From e3ac0822d81d486cde16d6cd32d560d056a42d4e Mon Sep 17 00:00:00 2001 From: Rick Chao Date: Fri, 31 Jul 2020 14:45:42 -0700 Subject: [PATCH] Add compiler pass to remove duplicate 'tf_saved_model.bound_input' bindings. Consolidate identical bound inputs so that resource variables do not alias in modules with tf_saved_model semantics. PiperOrigin-RevId: 324288443 Change-Id: I4ccf9c19f3e2df123667b71560c3d3ae3c751913 --- tensorflow/compiler/mlir/tensorflow/BUILD | 1 - .../mlir/tensorflow/ir/tf_saved_model.cc | 1 - .../tf_saved_model/hash_table_asset_v1.py | 17 ++--- ...odel_deduplicate_bound_input_bindings.mlir | 33 ---------- .../tensorflow/tests/tf_saved_model_ops.mlir | 13 ---- .../tests/tf_saved_model_ops_invalid.mlir | 14 ---- .../deduplicate_bound_input_bindings.cc | 65 ------------------- .../transforms/tf_saved_model_passes.h | 3 - .../mlir/tensorflow/translate/import_model.cc | 4 +- 9 files changed, 6 insertions(+), 145 deletions(-) delete mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_deduplicate_bound_input_bindings.mlir delete mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index c6f0083fc92..518992d03db 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -676,7 +676,6 @@ cc_library( cc_library( name = "tf_saved_model_passes", srcs = [ - "transforms/deduplicate_bound_input_bindings.cc", "transforms/freeze_global_tensors.cc", "transforms/lift_variables_pass.cc", "transforms/optimize_global_tensors.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index 94a792ec3db..edfc7feefd5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -337,7 +337,6 @@ LogicalResult VerifyExportedFunc(FuncOp func) { if (auto attr = func.getArgAttrOfType( i, "tf_saved_model.bound_input")) { if (!unique_bound_inputs.insert(attr.getValue()).second) { - if (module.getAttr("tf_saved_model.under_construction")) continue; return func.emitError() << "duplicate 'tf_saved_model.bound_input' binding"; } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py index 4cb931253b3..7e86953eb8f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py @@ -27,15 +27,13 @@ import tensorflow.compat.v1 as tf from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 # CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> () -# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset1:__tf_saved_model_asset1_.*]]"} -# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset0:__tf_saved_model_asset0_.*]]"} +# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset:.*]]"} # CHECK: func [[init]] -# CHECK-SAME: [[ARG0:%.*]]: tensor {tf_saved_model.bound_input = @[[asset0]]} -# CHECK-SAME: [[ARG1:%.*]]: tensor {tf_saved_model.bound_input = @[[asset1]]} +# CHECK-SAME: [[ARG:%.*]]: tensor {tf_saved_model.bound_input = @[[asset]]} # CHECK-NEXT: [[R0:%.*]] = "tf.HashTableV2"() # CHECK-SAME: shared_name = "[[hash_table:.*]]" -# CHECK-NEXT: "tf.InitializeTableFromTextFileV2"([[R0]], [[ARG0]]) +# CHECK-NEXT: "tf.InitializeTableFromTextFileV2"([[R0]], [[ARG]]) def write_vocabulary_file(vocabulary): @@ -50,16 +48,11 @@ def write_vocabulary_file(vocabulary): def test(): - vocabulary_file = write_vocabulary_file(['cat', 'is', 'on', 'the', 'mat']) table_initializer = tf.lookup.TextFileInitializer( - vocabulary_file, tf.string, tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64, + write_vocabulary_file(['cat', 'is', 'on', 'the', 'mat']), tf.string, + tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER) - # Incur another bound_input on the asset, but with a different sym_name, i.e., - # __tf_saved_model_asset1_tokens.txt vs. __tf_saved_model_asset0_tokens.txt. table = tf.lookup.StaticVocabularyTable(table_initializer, num_oov_buckets=10) - vocab_file_tensor = tf.convert_to_tensor(vocabulary_file, tf.string, - name='asset_filepath') - tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file_tensor) x = tf.placeholder(tf.string, shape=(), name='input') r = table.lookup(x) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_deduplicate_bound_input_bindings.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_deduplicate_bound_input_bindings.mlir deleted file mode 100644 index 22fd3d86068..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_deduplicate_bound_input_bindings.mlir +++ /dev/null @@ -1,33 +0,0 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-dedup-bound-input-binding-pass | FileCheck %s - -module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} { - // Test case: Remove duplicate bound_input symbols. - "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.0> : tensor } : () -> () - "tf_saved_model.global_tensor"() { is_mutable, sym_name = "w", type = tensor, value = dense<43.0> : tensor } : () -> () - "tf_saved_model.global_tensor"() { is_mutable, sym_name = "x", type = tensor, value = dense<44.0> : tensor } : () -> () - // CHECK: func @f - // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @v} - // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @w} - // CHECK: %arg2: tensor>> {tf_saved_model.bound_input = @x} - // CHECK-NOT: %arg3 - // CHECK-NOT: %arg4 - func @f( - %arg0: tensor>> {tf_saved_model.bound_input = @v}, - %arg1: tensor>> {tf_saved_model.bound_input = @w}, - %arg2: tensor>> {tf_saved_model.bound_input = @v}, - %arg3: tensor>> {tf_saved_model.bound_input = @x}, - %arg4: tensor>> {tf_saved_model.bound_input = @v} - ) attributes {tf_saved_model.exported_names = ["f"]} { - // CHECK: "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor - // CHECK: "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor - // CHECK: "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor - // CHECK: "tf.ReadVariableOp"(%arg2) : (tensor>>) -> tensor - // CHECK: "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor - %val0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor - %val1 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor - %val2 = "tf.ReadVariableOp"(%arg2) : (tensor>>) -> tensor - %val3 = "tf.ReadVariableOp"(%arg3) : (tensor>>) -> tensor - %val4 = "tf.ReadVariableOp"(%arg4) : (tensor>>) -> tensor - return - } -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir index d2c5509b52d..7156a1fab63 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir @@ -76,16 +76,3 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} } } - -// ----- - -module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} { - "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.0> : tensor } : () -> () - // CHECK: func @f - func @f( - %arg0: tensor>> {tf_saved_model.bound_input = @v}, - %arg1: tensor>> {tf_saved_model.bound_input = @v} - ) attributes {tf_saved_model.exported_names = ["f"]} { - return - } -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir index 714c8908825..dcb889ff99e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir @@ -400,17 +400,3 @@ module attributes {tf_saved_model.semantics} { } } - -// ----- - -module attributes {tf_saved_model.semantics} { - - "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.0> : tensor } : () -> () - // expected-error@+1 {{duplicate 'tf_saved_model.bound_input' binding}} - func @f( - %arg0: tensor>> {tf_saved_model.bound_input = @v}, - %arg1: tensor>> {tf_saved_model.bound_input = @v} - ) attributes {tf_saved_model.exported_names = ["f"]} { - return - } -} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc b/tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc deleted file mode 100644 index c1514dfa357..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "llvm/ADT/DenseMap.h" -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" - -namespace mlir { -namespace tf_saved_model { -namespace { - -class DedupBoundInputBindingPass - : public PassWrapper { - public: - void runOnFunction() override; -}; - -void DedupBoundInputBindingPass::runOnFunction() { - FuncOp func = getFunction(); - if (!mlir::tf_saved_model::IsExported(func)) return; - llvm::SmallDenseMap unique_bound_inputs; - llvm::SmallVector arg_indices_to_erase; - for (unsigned i = 0, e = func.getNumArguments(); i < e; i++) { - auto attr = func.getArgAttrOfType( - i, "tf_saved_model.bound_input"); - if (!attr) continue; - auto inserted = unique_bound_inputs.insert(std::make_pair(attr, i)); - if (inserted.second) continue; - auto duplicate_arg = func.getArgument(i); - auto original_arg = func.getArgument(unique_bound_inputs[attr]); - duplicate_arg.replaceAllUsesWith(original_arg); - arg_indices_to_erase.push_back(i); - } - func.eraseArguments(arg_indices_to_erase); -} - -} // namespace - -static PassRegistration pass( - "tf-saved-model-dedup-bound-input-binding-pass", - "Remove duplicate 'tf_saved_model.bound_input' bindings."); - -std::unique_ptr> CreateDedupBoundInputBindingPass() { - return std::make_unique(); -} - -} // namespace tf_saved_model -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h index 59532a2b123..f7a73dc1561 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h @@ -46,9 +46,6 @@ CreateRemoveVariablesInSessionInitializerPass(); std::unique_ptr> CreateLiftVariablesPass( ::tensorflow::Session* session); -// Creates a pass that removes duplicate 'tf_saved_model.bound_input' bindings. -std::unique_ptr> CreateDedupBoundInputBindingPass(); - } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 27385e81262..2c44aaa5c42 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -3368,13 +3368,12 @@ SavedModelSignatureDefImporter::ConvertAssets() { results.reserve(asset_file_defs.size()); mlir::OpBuilder builder(module_->getBodyRegion()); - unsigned i = 0; // Use to generate unique sym_name(s) for duplicate assets. for (const auto& asset : asset_file_defs) { auto asset_op = builder.create( module_->getLoc(), /*sym_name=*/ builder.getStringAttr( - absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())), + absl::StrCat("__tf_saved_model_asset_", asset.filename())), /*filename=*/ builder.getStringAttr( io::JoinPath(kSavedModelAssetsDirectory, asset.filename()))); @@ -3570,7 +3569,6 @@ Status SavedModelSignatureDefImporter::LiftVariables() { pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass()); pm.addPass( mlir::tf_saved_model::CreateLiftVariablesPass(bundle_.GetSession())); - pm.addPass(mlir::tf_saved_model::CreateDedupBoundInputBindingPass()); if (mlir::failed(pm.run(*module_))) return diag_handler.Combine(errors::Internal("Failed to lift variables."));