Add a new pass that freezes SavedModel's AssetOp

This new pass will replace a func's saved model asset bound inputs which are
bound to tf.InitializeTableFromTextFileV2Op ops with tf.Const ops inside the
func's body.

closes 

PiperOrigin-RevId: 358304291
Change-Id: I2bd2f6fbcafc30c878a7848eb4c107c3b48d9673
This commit is contained in:
Jaesung Chung 2021-02-18 17:45:05 -08:00 committed by TensorFlower Gardener
parent 810e10d628
commit 0795d9d94c
11 changed files with 247 additions and 9 deletions

View File

@ -972,6 +972,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:core_cpu_base",
"//tensorflow/lite/toco:model_flags_proto_cc",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",

View File

@ -96,8 +96,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
}
return internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{},
result,
model_flags, toco_flags, std::move(module), pass_config,
/*saved_model_tags=*/{}, result,
/*session=*/llvm::None);
}

View File

@ -184,7 +184,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
// TODO(b/153507667): Pass the session object when importing logic is removed.
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
toco_flags, std::move(module), pass_config, tags, result,
model_flags, toco_flags, std::move(module), pass_config, tags, result,
/*session=*/llvm::None);
return status;
}

View File

@ -288,8 +288,8 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
}
Status ConvertMLIRToTFLiteFlatBuffer(
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
const mlir::TFL::PassConfig& pass_config,
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module, const mlir::TFL::PassConfig& pass_config,
const std::unordered_set<std::string>& saved_model_tags, string* result,
llvm::Optional<tensorflow::Session*> session) {
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
@ -311,7 +311,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
mlir::OpPassManager::Nesting::Implicit);
::tensorflow::SetCrashReproducer(pm);
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm, session);
tensorflow::AddTFToTFLConversionPasses(model_flags, pass_config, &pm,
session);
// Convert back to outlined while format for export back to flatbuffer.
if (pass_config.legalize_tf_while) {
pm.addPass(mlir::TFL::CreateWhileOutlinePass());

View File

@ -48,8 +48,8 @@ Status PopulateQuantizationSpecs(
// Convert imported MLIR file to TfLite flatbuffer.
// This will also run relevant passes as well.
Status ConvertMLIRToTFLiteFlatBuffer(
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
const mlir::TFL::PassConfig& pass_config,
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
mlir::OwningModuleRef module, const mlir::TFL::PassConfig& pass_config,
const std::unordered_set<std::string>& saved_model_tags, string* result,
llvm::Optional<tensorflow::Session*> session);

View File

@ -62,7 +62,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
}
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager,
llvm::Optional<tensorflow::Session*> session) {
mlir::TF::StandardPipelineOptions standard_pipeline_options;
@ -175,6 +176,13 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
/*allow_mutable_tensors=*/pass_config.enable_tflite_variables));
}
if (!model_flags.saved_model_dir().empty()) {
// This pass 'freezes' tf saved model asset ops and inlines as string values
// in a format of the tf constant op.
pass_manager->addPass(mlir::tf_saved_model::CreateFreezeAssetsPass(
model_flags.saved_model_dir()));
}
// The below passes only make sense if Builtin TFLite ops are enabled
// for emission.
if (pass_config.emit_builtin_tflite_ops) {
@ -253,6 +261,13 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
}
}
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager,
llvm::Optional<tensorflow::Session*> session) {
const toco::ModelFlags model_flags;
AddTFToTFLConversionPasses(model_flags, pass_config, pass_manager, session);
}
} // namespace tensorflow
namespace mlir {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
namespace tensorflow {
@ -28,6 +29,11 @@ namespace tensorflow {
// pass_manager. The session object will be provided when the TF MLIR is
// imported from saved model version one and utilized for capturing resource
// variables.
void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager,
llvm::Optional<tensorflow::Session*> session);
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
mlir::OpPassManager* pass_manager,
llvm::Optional<tensorflow::Session*> session);

View File

@ -788,6 +788,7 @@ cc_library(
srcs = [
"transforms/deduplicate_bound_input_bindings.cc",
"transforms/freeze_global_tensors.cc",
"transforms/freeze_saved_model_assets.cc",
"transforms/lift_variables_pass.cc",
"transforms/optimize_global_tensors.cc",
"transforms/remove_vars_in_session_initializer.cc",

View File

@ -0,0 +1,86 @@
// RUN: tf-opt -verify-diagnostics -tf-saved-model-freeze-assets -split-input-file %s | FileCheck %s
module attributes {tf_saved_model.semantics} {
// Test case: Basic freezing.
"tf_saved_model.asset"() {filename = "assets/table.txt", sym_name = "v"} : () -> ()
// CHECK: func @f()
func @f(%arg0: tensor<!tf.string> {tf_saved_model.bound_input = @v})
attributes {tf_saved_model.exported_names = ["f"]} {
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
"tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
// CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<"assets/table.txt"> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string>
// CHECK: [[HASHTABLE:%.+]] = "tf.HashTableV2"()
// CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE]], [[CST]])
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Sanity check handling of non-bound inputs.
// The pass shouldn't do anything in this case.
// CHECK: func @f(%arg0
func @f(%arg0: tensor<!tf.string> {tf_saved_model.index_path = [0]})
attributes {tf_saved_model.exported_names = ["f"]} {
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
"tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
// CHECK: "tf.InitializeTableFromTextFileV2"(%0, %arg0)
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Sanity check handling of non tf.InitializeTableFromTextFileV2 op usages.
"tf_saved_model.asset"() {filename = "assets/table.txt", sym_name = "v"} : () -> ()
// CHECK: func @f(%arg0
func @f(%arg0: tensor<!tf.string> {tf_saved_model.bound_input = @v})
attributes {tf_saved_model.exported_names = ["f"]} {
"tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.string>) -> ()
return
}
func private @f_callee(%arg0: tensor<!tf.string>) {
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
"tf_saved_model.asset"() {filename = "assets/table.txt", sym_name = "v"} : () -> ()
"tf_saved_model.asset"() {filename = "assets/table2.txt", sym_name = "w"} : () -> ()
// CHECK: func @f()
func @f(%arg0: tensor<!tf.string> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf.string> {tf_saved_model.bound_input = @w})
attributes {tf_saved_model.exported_names = ["f"]} {
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
"tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
%1 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
"tf.InitializeTableFromTextFileV2"(%1, %arg1) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
// CHECK: [[CST_1:%.+]] = "tf.Const"() {value = dense<"assets/table2.txt"> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string>
// CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<"assets/table.txt"> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string>
// CHECK: [[HASHTABLE:%.+]] = "tf.HashTableV2"()
// CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE]], [[CST]])
// CHECK: [[HASHTABLE_1:%.+]] = "tf.HashTableV2"()
// CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE_1]], [[CST_1]])
return
}
}
// -----
// Test running the pass on a module that does not have
// tf_saved_model.semantics.
module {}

View File

@ -0,0 +1,124 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <string>
#include <vector>
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/core/platform/path.h"
namespace mlir {
namespace tf_saved_model {
namespace {
// This pass will replace a func's saved model asset bound inputs which are
// bound to tf.InitializeTableFromTextFileV2Op ops with tf.Const ops inside the
// func's body.
struct FreezeAssetsPass
: public PassWrapper<FreezeAssetsPass, OperationPass<ModuleOp>> {
FreezeAssetsPass() = default;
FreezeAssetsPass(const FreezeAssetsPass& pass) {}
explicit FreezeAssetsPass(std::string saved_model_dir) {
this->saved_model_dir = saved_model_dir;
}
void runOnOperation() override;
private:
std::string saved_model_dir;
};
void FreezeAssetsPass::runOnOperation() {
auto module = getOperation();
if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
return;
}
SymbolTable symbol_table(module);
for (auto func : module.getOps<FuncOp>()) {
SmallVector<unsigned, 4> args_to_erase;
OpBuilder builder(func.getBody());
for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
SmallVector<TF::InitializeTableFromTextFileV2Op, 4>
init_table_from_text_file_ops_to_erase;
auto asset = LookupBoundInputOfType<AssetOp>(func, i, symbol_table);
if (!asset) continue;
auto arg = func.getArgument(i);
bool arg_is_deletable = true;
for (auto user : arg.getUsers()) {
if (auto read_op =
llvm::dyn_cast<TF::InitializeTableFromTextFileV2Op>(user)) {
init_table_from_text_file_ops_to_erase.push_back(read_op);
} else {
arg_is_deletable = false;
continue;
}
}
if (arg_is_deletable) {
args_to_erase.push_back(i);
}
// Replace the arg with a tf.Const op in the function body.
builder.setInsertionPointToStart(&func.getBody().front());
std::string asset_filename = asset.filename().str();
std::string filename =
tensorflow::io::JoinPath(saved_model_dir, asset_filename);
ShapedType shaped_type =
RankedTensorType::get({1}, TF::StringType::get(builder.getContext()));
auto const_op = builder.create<TF::ConstOp>(
asset.getLoc(),
DenseStringElementsAttr::get(shaped_type, {filename}));
for (auto init_op : init_table_from_text_file_ops_to_erase) {
// Replace the InitializeTableFromTextFileV2Op to use the saved model's
// asset filepath.
builder.setInsertionPoint(init_op);
builder.create<TF::InitializeTableFromTextFileV2Op>(
init_op.getLoc(), init_op.table_handle(), const_op.getResult(),
init_op.key_index(), init_op.value_index(), init_op.vocab_size(),
init_op.delimiter());
init_op.erase();
}
}
func.eraseArguments(args_to_erase);
}
}
} // namespace
// For "opt" to pick up this pass.
static PassRegistration<FreezeAssetsPass> freeze_assets_pass(
"tf-saved-model-freeze-assets",
"Freeze tf_saved_model.asset's in func bodies.");
std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeAssetsPass(
std::string saved_model_dir) {
return std::make_unique<FreezeAssetsPass>(saved_model_dir);
}
} // namespace tf_saved_model
} // namespace mlir

View File

@ -31,6 +31,10 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass();
std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass(
bool allow_mutable_tensors = false);
// Creates a pass that freezes tf_saved_model.asset ops.
std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeAssetsPass(
std::string saved_model_dir);
// Creates as pass that removes variables in the session initializer.
// This job is required with lifting variable passes. Originally, the session
// initializer function does assigning variables. However, the read-only