Add a new pass that freezes SavedModel's AssetOp
This new pass will replace a func's saved model asset bound inputs which are bound to tf.InitializeTableFromTextFileV2Op ops with tf.Const ops inside the func's body. closes #46737 PiperOrigin-RevId: 358304291 Change-Id: I2bd2f6fbcafc30c878a7848eb4c107c3b48d9673
This commit is contained in:
parent
810e10d628
commit
0795d9d94c
tensorflow/compiler/mlir
@ -972,6 +972,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/lite/toco:model_flags_proto_cc",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
@ -96,8 +96,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
}
|
||||
|
||||
return internal::ConvertMLIRToTFLiteFlatBuffer(
|
||||
toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{},
|
||||
result,
|
||||
model_flags, toco_flags, std::move(module), pass_config,
|
||||
/*saved_model_tags=*/{}, result,
|
||||
/*session=*/llvm::None);
|
||||
}
|
||||
|
||||
|
@ -184,7 +184,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
|
||||
|
||||
// TODO(b/153507667): Pass the session object when importing logic is removed.
|
||||
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
|
||||
toco_flags, std::move(module), pass_config, tags, result,
|
||||
model_flags, toco_flags, std::move(module), pass_config, tags, result,
|
||||
/*session=*/llvm::None);
|
||||
return status;
|
||||
}
|
||||
|
@ -288,8 +288,8 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) {
|
||||
}
|
||||
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(
|
||||
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::OwningModuleRef module, const mlir::TFL::PassConfig& pass_config,
|
||||
const std::unordered_set<std::string>& saved_model_tags, string* result,
|
||||
llvm::Optional<tensorflow::Session*> session) {
|
||||
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
|
||||
@ -311,7 +311,8 @@ Status ConvertMLIRToTFLiteFlatBuffer(
|
||||
mlir::OpPassManager::Nesting::Implicit);
|
||||
::tensorflow::SetCrashReproducer(pm);
|
||||
|
||||
tensorflow::AddTFToTFLConversionPasses(pass_config, &pm, session);
|
||||
tensorflow::AddTFToTFLConversionPasses(model_flags, pass_config, &pm,
|
||||
session);
|
||||
// Convert back to outlined while format for export back to flatbuffer.
|
||||
if (pass_config.legalize_tf_while) {
|
||||
pm.addPass(mlir::TFL::CreateWhileOutlinePass());
|
||||
|
@ -48,8 +48,8 @@ Status PopulateQuantizationSpecs(
|
||||
// Convert imported MLIR file to TfLite flatbuffer.
|
||||
// This will also run relevant passes as well.
|
||||
Status ConvertMLIRToTFLiteFlatBuffer(
|
||||
const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
|
||||
mlir::OwningModuleRef module, const mlir::TFL::PassConfig& pass_config,
|
||||
const std::unordered_set<std::string>& saved_model_tags, string* result,
|
||||
llvm::Optional<tensorflow::Session*> session);
|
||||
|
||||
|
@ -62,7 +62,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs,
|
||||
mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
|
||||
}
|
||||
|
||||
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
mlir::OpPassManager* pass_manager,
|
||||
llvm::Optional<tensorflow::Session*> session) {
|
||||
mlir::TF::StandardPipelineOptions standard_pipeline_options;
|
||||
@ -175,6 +176,13 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
/*allow_mutable_tensors=*/pass_config.enable_tflite_variables));
|
||||
}
|
||||
|
||||
if (!model_flags.saved_model_dir().empty()) {
|
||||
// This pass 'freezes' tf saved model asset ops and inlines as string values
|
||||
// in a format of the tf constant op.
|
||||
pass_manager->addPass(mlir::tf_saved_model::CreateFreezeAssetsPass(
|
||||
model_flags.saved_model_dir()));
|
||||
}
|
||||
|
||||
// The below passes only make sense if Builtin TFLite ops are enabled
|
||||
// for emission.
|
||||
if (pass_config.emit_builtin_tflite_ops) {
|
||||
@ -253,6 +261,13 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
}
|
||||
}
|
||||
|
||||
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
mlir::OpPassManager* pass_manager,
|
||||
llvm::Optional<tensorflow::Session*> session) {
|
||||
const toco::ModelFlags model_flags;
|
||||
AddTFToTFLConversionPasses(model_flags, pass_config, pass_manager, session);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace mlir {
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/lite/toco/model_flags.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -28,6 +29,11 @@ namespace tensorflow {
|
||||
// pass_manager. The session object will be provided when the TF MLIR is
|
||||
// imported from saved model version one and utilized for capturing resource
|
||||
// variables.
|
||||
void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags,
|
||||
const mlir::TFL::PassConfig& pass_config,
|
||||
mlir::OpPassManager* pass_manager,
|
||||
llvm::Optional<tensorflow::Session*> session);
|
||||
|
||||
void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
mlir::OpPassManager* pass_manager,
|
||||
llvm::Optional<tensorflow::Session*> session);
|
||||
|
@ -788,6 +788,7 @@ cc_library(
|
||||
srcs = [
|
||||
"transforms/deduplicate_bound_input_bindings.cc",
|
||||
"transforms/freeze_global_tensors.cc",
|
||||
"transforms/freeze_saved_model_assets.cc",
|
||||
"transforms/lift_variables_pass.cc",
|
||||
"transforms/optimize_global_tensors.cc",
|
||||
"transforms/remove_vars_in_session_initializer.cc",
|
||||
|
@ -0,0 +1,86 @@
|
||||
// RUN: tf-opt -verify-diagnostics -tf-saved-model-freeze-assets -split-input-file %s | FileCheck %s
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Basic freezing.
|
||||
|
||||
"tf_saved_model.asset"() {filename = "assets/table.txt", sym_name = "v"} : () -> ()
|
||||
|
||||
// CHECK: func @f()
|
||||
func @f(%arg0: tensor<!tf.string> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
|
||||
"tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
|
||||
// CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<"assets/table.txt"> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string>
|
||||
// CHECK: [[HASHTABLE:%.+]] = "tf.HashTableV2"()
|
||||
// CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE]], [[CST]])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Sanity check handling of non-bound inputs.
|
||||
// The pass shouldn't do anything in this case.
|
||||
|
||||
// CHECK: func @f(%arg0
|
||||
func @f(%arg0: tensor<!tf.string> {tf_saved_model.index_path = [0]})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
|
||||
"tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
|
||||
// CHECK: "tf.InitializeTableFromTextFileV2"(%0, %arg0)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Sanity check handling of non tf.InitializeTableFromTextFileV2 op usages.
|
||||
|
||||
"tf_saved_model.asset"() {filename = "assets/table.txt", sym_name = "v"} : () -> ()
|
||||
|
||||
// CHECK: func @f(%arg0
|
||||
func @f(%arg0: tensor<!tf.string> {tf_saved_model.bound_input = @v})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
"tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor<!tf.string>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
func private @f_callee(%arg0: tensor<!tf.string>) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
"tf_saved_model.asset"() {filename = "assets/table.txt", sym_name = "v"} : () -> ()
|
||||
"tf_saved_model.asset"() {filename = "assets/table2.txt", sym_name = "w"} : () -> ()
|
||||
|
||||
// CHECK: func @f()
|
||||
func @f(%arg0: tensor<!tf.string> {tf_saved_model.bound_input = @v}, %arg1: tensor<!tf.string> {tf_saved_model.bound_input = @w})
|
||||
attributes {tf_saved_model.exported_names = ["f"]} {
|
||||
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
|
||||
"tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
|
||||
%1 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor<!tf.resource>
|
||||
"tf.InitializeTableFromTextFileV2"(%1, %arg1) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor<!tf.resource>, tensor<!tf.string>) -> ()
|
||||
// CHECK: [[CST_1:%.+]] = "tf.Const"() {value = dense<"assets/table2.txt"> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string>
|
||||
// CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<"assets/table.txt"> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string>
|
||||
// CHECK: [[HASHTABLE:%.+]] = "tf.HashTableV2"()
|
||||
// CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE]], [[CST]])
|
||||
// CHECK: [[HASHTABLE_1:%.+]] = "tf.HashTableV2"()
|
||||
// CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE_1]], [[CST_1]])
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test running the pass on a module that does not have
|
||||
// tf_saved_model.semantics.
|
||||
module {}
|
@ -0,0 +1,124 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
|
||||
#include "mlir/IR/UseDefLists.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
namespace {
|
||||
|
||||
// This pass will replace a func's saved model asset bound inputs which are
|
||||
// bound to tf.InitializeTableFromTextFileV2Op ops with tf.Const ops inside the
|
||||
// func's body.
|
||||
struct FreezeAssetsPass
|
||||
: public PassWrapper<FreezeAssetsPass, OperationPass<ModuleOp>> {
|
||||
FreezeAssetsPass() = default;
|
||||
|
||||
FreezeAssetsPass(const FreezeAssetsPass& pass) {}
|
||||
explicit FreezeAssetsPass(std::string saved_model_dir) {
|
||||
this->saved_model_dir = saved_model_dir;
|
||||
}
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
std::string saved_model_dir;
|
||||
};
|
||||
|
||||
void FreezeAssetsPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
|
||||
return;
|
||||
}
|
||||
SymbolTable symbol_table(module);
|
||||
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
SmallVector<unsigned, 4> args_to_erase;
|
||||
OpBuilder builder(func.getBody());
|
||||
|
||||
for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
|
||||
SmallVector<TF::InitializeTableFromTextFileV2Op, 4>
|
||||
init_table_from_text_file_ops_to_erase;
|
||||
auto asset = LookupBoundInputOfType<AssetOp>(func, i, symbol_table);
|
||||
|
||||
if (!asset) continue;
|
||||
|
||||
auto arg = func.getArgument(i);
|
||||
bool arg_is_deletable = true;
|
||||
for (auto user : arg.getUsers()) {
|
||||
if (auto read_op =
|
||||
llvm::dyn_cast<TF::InitializeTableFromTextFileV2Op>(user)) {
|
||||
init_table_from_text_file_ops_to_erase.push_back(read_op);
|
||||
} else {
|
||||
arg_is_deletable = false;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (arg_is_deletable) {
|
||||
args_to_erase.push_back(i);
|
||||
}
|
||||
|
||||
// Replace the arg with a tf.Const op in the function body.
|
||||
builder.setInsertionPointToStart(&func.getBody().front());
|
||||
|
||||
std::string asset_filename = asset.filename().str();
|
||||
std::string filename =
|
||||
tensorflow::io::JoinPath(saved_model_dir, asset_filename);
|
||||
ShapedType shaped_type =
|
||||
RankedTensorType::get({1}, TF::StringType::get(builder.getContext()));
|
||||
auto const_op = builder.create<TF::ConstOp>(
|
||||
asset.getLoc(),
|
||||
DenseStringElementsAttr::get(shaped_type, {filename}));
|
||||
for (auto init_op : init_table_from_text_file_ops_to_erase) {
|
||||
// Replace the InitializeTableFromTextFileV2Op to use the saved model's
|
||||
// asset filepath.
|
||||
builder.setInsertionPoint(init_op);
|
||||
builder.create<TF::InitializeTableFromTextFileV2Op>(
|
||||
init_op.getLoc(), init_op.table_handle(), const_op.getResult(),
|
||||
init_op.key_index(), init_op.value_index(), init_op.vocab_size(),
|
||||
init_op.delimiter());
|
||||
init_op.erase();
|
||||
}
|
||||
}
|
||||
func.eraseArguments(args_to_erase);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// For "opt" to pick up this pass.
|
||||
static PassRegistration<FreezeAssetsPass> freeze_assets_pass(
|
||||
"tf-saved-model-freeze-assets",
|
||||
"Freeze tf_saved_model.asset's in func bodies.");
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeAssetsPass(
|
||||
std::string saved_model_dir) {
|
||||
return std::make_unique<FreezeAssetsPass>(saved_model_dir);
|
||||
}
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
@ -31,6 +31,10 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass(
|
||||
bool allow_mutable_tensors = false);
|
||||
|
||||
// Creates a pass that freezes tf_saved_model.asset ops.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeAssetsPass(
|
||||
std::string saved_model_dir);
|
||||
|
||||
// Creates as pass that removes variables in the session initializer.
|
||||
// This job is required with lifting variable passes. Originally, the session
|
||||
// initializer function does assigning variables. However, the read-only
|
||||
|
Loading…
Reference in New Issue
Block a user