Add compiler pass to remove duplicate 'tf_saved_model.bound_input' bindings.

Consolidate identical bound inputs so that resource variables do not alias in modules with tf_saved_model semantics.

PiperOrigin-RevId: 324288443
Change-Id: I4ccf9c19f3e2df123667b71560c3d3ae3c751913
This commit is contained in:
Rick Chao 2020-07-31 14:45:42 -07:00 committed by TensorFlower Gardener
parent 8a2c608cf7
commit e3ac0822d8
9 changed files with 6 additions and 145 deletions

View File

@ -676,7 +676,6 @@ cc_library(
cc_library(
name = "tf_saved_model_passes",
srcs = [
"transforms/deduplicate_bound_input_bindings.cc",
"transforms/freeze_global_tensors.cc",
"transforms/lift_variables_pass.cc",
"transforms/optimize_global_tensors.cc",

View File

@ -337,7 +337,6 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
i, "tf_saved_model.bound_input")) {
if (!unique_bound_inputs.insert(attr.getValue()).second) {
if (module.getAttr("tf_saved_model.under_construction")) continue;
return func.emitError()
<< "duplicate 'tf_saved_model.bound_input' binding";
}

View File

@ -27,15 +27,13 @@ import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset1:__tf_saved_model_asset1_.*]]"}
# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset0:__tf_saved_model_asset0_.*]]"}
# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset:.*]]"}
# CHECK: func [[init]]
# CHECK-SAME: [[ARG0:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[asset0]]}
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[asset1]]}
# CHECK-SAME: [[ARG:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[asset]]}
# CHECK-NEXT: [[R0:%.*]] = "tf.HashTableV2"()
# CHECK-SAME: shared_name = "[[hash_table:.*]]"
# CHECK-NEXT: "tf.InitializeTableFromTextFileV2"([[R0]], [[ARG0]])
# CHECK-NEXT: "tf.InitializeTableFromTextFileV2"([[R0]], [[ARG]])
def write_vocabulary_file(vocabulary):
@ -50,16 +48,11 @@ def write_vocabulary_file(vocabulary):
def test():
vocabulary_file = write_vocabulary_file(['cat', 'is', 'on', 'the', 'mat'])
table_initializer = tf.lookup.TextFileInitializer(
vocabulary_file, tf.string, tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64,
write_vocabulary_file(['cat', 'is', 'on', 'the', 'mat']), tf.string,
tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64,
tf.lookup.TextFileIndex.LINE_NUMBER)
# Incur another bound_input on the asset, but with a different sym_name, i.e.,
# __tf_saved_model_asset1_tokens.txt vs. __tf_saved_model_asset0_tokens.txt.
table = tf.lookup.StaticVocabularyTable(table_initializer, num_oov_buckets=10)
vocab_file_tensor = tf.convert_to_tensor(vocabulary_file, tf.string,
name='asset_filepath')
tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file_tensor)
x = tf.placeholder(tf.string, shape=(), name='input')
r = table.lookup(x)

View File

@ -1,33 +0,0 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-dedup-bound-input-binding-pass | FileCheck %s
module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} {
// Test case: Remove duplicate bound_input symbols.
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "w", type = tensor<f32>, value = dense<43.0> : tensor<f32> } : () -> ()
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "x", type = tensor<f32>, value = dense<44.0> : tensor<f32> } : () -> ()
// CHECK: func @f
// CHECK: %arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}
// CHECK: %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @w}
// CHECK: %arg2: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x}
// CHECK-NOT: %arg3
// CHECK-NOT: %arg4
func @f(
%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v},
%arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @w},
%arg2: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v},
%arg3: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x},
%arg4: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}
) attributes {tf_saved_model.exported_names = ["f"]} {
// CHECK: "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// CHECK: "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// CHECK: "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// CHECK: "tf.ReadVariableOp"(%arg2) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
// CHECK: "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
%val0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
%val1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
%val2 = "tf.ReadVariableOp"(%arg2) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
%val3 = "tf.ReadVariableOp"(%arg3) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
%val4 = "tf.ReadVariableOp"(%arg4) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
return
}
}

View File

@ -76,16 +76,3 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
}
}
// -----
module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} {
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
// CHECK: func @f
func @f(
%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v},
%arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}
) attributes {tf_saved_model.exported_names = ["f"]} {
return
}
}

View File

@ -400,17 +400,3 @@ module attributes {tf_saved_model.semantics} {
}
}
// -----
module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
// expected-error@+1 {{duplicate 'tf_saved_model.bound_input' binding}}
func @f(
%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v},
%arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}
) attributes {tf_saved_model.exported_names = ["f"]} {
return
}
}

View File

@ -1,65 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
namespace mlir {
namespace tf_saved_model {
namespace {
class DedupBoundInputBindingPass
: public PassWrapper<DedupBoundInputBindingPass, FunctionPass> {
public:
void runOnFunction() override;
};
void DedupBoundInputBindingPass::runOnFunction() {
FuncOp func = getFunction();
if (!mlir::tf_saved_model::IsExported(func)) return;
llvm::SmallDenseMap<Attribute, unsigned, 8> unique_bound_inputs;
llvm::SmallVector<unsigned, 8> arg_indices_to_erase;
for (unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
i, "tf_saved_model.bound_input");
if (!attr) continue;
auto inserted = unique_bound_inputs.insert(std::make_pair(attr, i));
if (inserted.second) continue;
auto duplicate_arg = func.getArgument(i);
auto original_arg = func.getArgument(unique_bound_inputs[attr]);
duplicate_arg.replaceAllUsesWith(original_arg);
arg_indices_to_erase.push_back(i);
}
func.eraseArguments(arg_indices_to_erase);
}
} // namespace
static PassRegistration<DedupBoundInputBindingPass> pass(
"tf-saved-model-dedup-bound-input-binding-pass",
"Remove duplicate 'tf_saved_model.bound_input' bindings.");
std::unique_ptr<OperationPass<FuncOp>> CreateDedupBoundInputBindingPass() {
return std::make_unique<DedupBoundInputBindingPass>();
}
} // namespace tf_saved_model
} // namespace mlir

View File

@ -46,9 +46,6 @@ CreateRemoveVariablesInSessionInitializerPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftVariablesPass(
::tensorflow::Session* session);
// Creates a pass that removes duplicate 'tf_saved_model.bound_input' bindings.
std::unique_ptr<OperationPass<ModuleOp>> CreateDedupBoundInputBindingPass();
} // namespace tf_saved_model
} // namespace mlir

View File

@ -3368,13 +3368,12 @@ SavedModelSignatureDefImporter::ConvertAssets() {
results.reserve(asset_file_defs.size());
mlir::OpBuilder builder(module_->getBodyRegion());
unsigned i = 0; // Use to generate unique sym_name(s) for duplicate assets.
for (const auto& asset : asset_file_defs) {
auto asset_op = builder.create<mlir::tf_saved_model::AssetOp>(
module_->getLoc(),
/*sym_name=*/
builder.getStringAttr(
absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())),
absl::StrCat("__tf_saved_model_asset_", asset.filename())),
/*filename=*/
builder.getStringAttr(
io::JoinPath(kSavedModelAssetsDirectory, asset.filename())));
@ -3570,7 +3569,6 @@ Status SavedModelSignatureDefImporter::LiftVariables() {
pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
pm.addPass(
mlir::tf_saved_model::CreateLiftVariablesPass(bundle_.GetSession()));
pm.addPass(mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
if (mlir::failed(pm.run(*module_)))
return diag_handler.Combine(errors::Internal("Failed to lift variables."));