Add compiler pass to remove duplicate 'tf_saved_model.bound_input' bindings.
Consolidate identical bound inputs so that resource variables do not alias in modules with tf_saved_model semantics. PiperOrigin-RevId: 324288443 Change-Id: I4ccf9c19f3e2df123667b71560c3d3ae3c751913
This commit is contained in:
parent
8a2c608cf7
commit
e3ac0822d8
@ -676,7 +676,6 @@ cc_library(
|
||||
cc_library(
|
||||
name = "tf_saved_model_passes",
|
||||
srcs = [
|
||||
"transforms/deduplicate_bound_input_bindings.cc",
|
||||
"transforms/freeze_global_tensors.cc",
|
||||
"transforms/lift_variables_pass.cc",
|
||||
"transforms/optimize_global_tensors.cc",
|
||||
|
@ -337,7 +337,6 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
|
||||
if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
|
||||
i, "tf_saved_model.bound_input")) {
|
||||
if (!unique_bound_inputs.insert(attr.getValue()).second) {
|
||||
if (module.getAttr("tf_saved_model.under_construction")) continue;
|
||||
return func.emitError()
|
||||
<< "duplicate 'tf_saved_model.bound_input' binding";
|
||||
}
|
||||
|
@ -27,15 +27,13 @@ import tensorflow.compat.v1 as tf
|
||||
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||
|
||||
# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
|
||||
# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset1:__tf_saved_model_asset1_.*]]"}
|
||||
# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset0:__tf_saved_model_asset0_.*]]"}
|
||||
# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset:.*]]"}
|
||||
|
||||
# CHECK: func [[init]]
|
||||
# CHECK-SAME: [[ARG0:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[asset0]]}
|
||||
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[asset1]]}
|
||||
# CHECK-SAME: [[ARG:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[asset]]}
|
||||
# CHECK-NEXT: [[R0:%.*]] = "tf.HashTableV2"()
|
||||
# CHECK-SAME: shared_name = "[[hash_table:.*]]"
|
||||
# CHECK-NEXT: "tf.InitializeTableFromTextFileV2"([[R0]], [[ARG0]])
|
||||
# CHECK-NEXT: "tf.InitializeTableFromTextFileV2"([[R0]], [[ARG]])
|
||||
|
||||
|
||||
def write_vocabulary_file(vocabulary):
|
||||
@ -50,16 +48,11 @@ def write_vocabulary_file(vocabulary):
|
||||
|
||||
def test():
|
||||
|
||||
vocabulary_file = write_vocabulary_file(['cat', 'is', 'on', 'the', 'mat'])
|
||||
table_initializer = tf.lookup.TextFileInitializer(
|
||||
vocabulary_file, tf.string, tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64,
|
||||
write_vocabulary_file(['cat', 'is', 'on', 'the', 'mat']), tf.string,
|
||||
tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64,
|
||||
tf.lookup.TextFileIndex.LINE_NUMBER)
|
||||
# Incur another bound_input on the asset, but with a different sym_name, i.e.,
|
||||
# __tf_saved_model_asset1_tokens.txt vs. __tf_saved_model_asset0_tokens.txt.
|
||||
table = tf.lookup.StaticVocabularyTable(table_initializer, num_oov_buckets=10)
|
||||
vocab_file_tensor = tf.convert_to_tensor(vocabulary_file, tf.string,
|
||||
name='asset_filepath')
|
||||
tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file_tensor)
|
||||
|
||||
x = tf.placeholder(tf.string, shape=(), name='input')
|
||||
r = table.lookup(x)
|
||||
|
@ -1,33 +0,0 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-dedup-bound-input-binding-pass | FileCheck %s
|
||||
|
||||
module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} {
|
||||
// Test case: Remove duplicate bound_input symbols.
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "w", type = tensor<f32>, value = dense<43.0> : tensor<f32> } : () -> ()
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "x", type = tensor<f32>, value = dense<44.0> : tensor<f32> } : () -> ()
|
||||
// CHECK: func @f
|
||||
// CHECK: %arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}
|
||||
// CHECK: %arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @w}
|
||||
// CHECK: %arg2: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x}
|
||||
// CHECK-NOT: %arg3
|
||||
// CHECK-NOT: %arg4
|
||||
func @f(
|
||||
%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v},
|
||||
%arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @w},
|
||||
%arg2: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v},
|
||||
%arg3: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @x},
|
||||
%arg4: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}
|
||||
) attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
// CHECK: "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
// CHECK: "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
// CHECK: "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
// CHECK: "tf.ReadVariableOp"(%arg2) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
// CHECK: "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%val0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%val1 = "tf.ReadVariableOp"(%arg1) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%val2 = "tf.ReadVariableOp"(%arg2) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%val3 = "tf.ReadVariableOp"(%arg3) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
%val4 = "tf.ReadVariableOp"(%arg4) : (tensor<!tf.resource<tensor<f32>>>) -> tensor<f32>
|
||||
return
|
||||
}
|
||||
}
|
@ -76,16 +76,3 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} {
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
|
||||
// CHECK: func @f
|
||||
func @f(
|
||||
%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v},
|
||||
%arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}
|
||||
) attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -400,17 +400,3 @@ module attributes {tf_saved_model.semantics} {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<f32>, value = dense<42.0> : tensor<f32> } : () -> ()
|
||||
// expected-error@+1 {{duplicate 'tf_saved_model.bound_input' binding}}
|
||||
func @f(
|
||||
%arg0: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v},
|
||||
%arg1: tensor<!tf.resource<tensor<f32>>> {tf_saved_model.bound_input = @v}
|
||||
) attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -1,65 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
namespace {
|
||||
|
||||
class DedupBoundInputBindingPass
|
||||
: public PassWrapper<DedupBoundInputBindingPass, FunctionPass> {
|
||||
public:
|
||||
void runOnFunction() override;
|
||||
};
|
||||
|
||||
void DedupBoundInputBindingPass::runOnFunction() {
|
||||
FuncOp func = getFunction();
|
||||
if (!mlir::tf_saved_model::IsExported(func)) return;
|
||||
llvm::SmallDenseMap<Attribute, unsigned, 8> unique_bound_inputs;
|
||||
llvm::SmallVector<unsigned, 8> arg_indices_to_erase;
|
||||
for (unsigned i = 0, e = func.getNumArguments(); i < e; i++) {
|
||||
auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
|
||||
i, "tf_saved_model.bound_input");
|
||||
if (!attr) continue;
|
||||
auto inserted = unique_bound_inputs.insert(std::make_pair(attr, i));
|
||||
if (inserted.second) continue;
|
||||
auto duplicate_arg = func.getArgument(i);
|
||||
auto original_arg = func.getArgument(unique_bound_inputs[attr]);
|
||||
duplicate_arg.replaceAllUsesWith(original_arg);
|
||||
arg_indices_to_erase.push_back(i);
|
||||
}
|
||||
func.eraseArguments(arg_indices_to_erase);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
static PassRegistration<DedupBoundInputBindingPass> pass(
|
||||
"tf-saved-model-dedup-bound-input-binding-pass",
|
||||
"Remove duplicate 'tf_saved_model.bound_input' bindings.");
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDedupBoundInputBindingPass() {
|
||||
return std::make_unique<DedupBoundInputBindingPass>();
|
||||
}
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
@ -46,9 +46,6 @@ CreateRemoveVariablesInSessionInitializerPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftVariablesPass(
|
||||
::tensorflow::Session* session);
|
||||
|
||||
// Creates a pass that removes duplicate 'tf_saved_model.bound_input' bindings.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateDedupBoundInputBindingPass();
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -3368,13 +3368,12 @@ SavedModelSignatureDefImporter::ConvertAssets() {
|
||||
results.reserve(asset_file_defs.size());
|
||||
|
||||
mlir::OpBuilder builder(module_->getBodyRegion());
|
||||
unsigned i = 0; // Use to generate unique sym_name(s) for duplicate assets.
|
||||
for (const auto& asset : asset_file_defs) {
|
||||
auto asset_op = builder.create<mlir::tf_saved_model::AssetOp>(
|
||||
module_->getLoc(),
|
||||
/*sym_name=*/
|
||||
builder.getStringAttr(
|
||||
absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())),
|
||||
absl::StrCat("__tf_saved_model_asset_", asset.filename())),
|
||||
/*filename=*/
|
||||
builder.getStringAttr(
|
||||
io::JoinPath(kSavedModelAssetsDirectory, asset.filename())));
|
||||
@ -3570,7 +3569,6 @@ Status SavedModelSignatureDefImporter::LiftVariables() {
|
||||
pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass());
|
||||
pm.addPass(
|
||||
mlir::tf_saved_model::CreateLiftVariablesPass(bundle_.GetSession()));
|
||||
pm.addPass(mlir::tf_saved_model::CreateDedupBoundInputBindingPass());
|
||||
if (mlir::failed(pm.run(*module_)))
|
||||
return diag_handler.Combine(errors::Internal("Failed to lift variables."));
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user