From 0795d9d94c661301a1a89918476e133e0f05334c Mon Sep 17 00:00:00 2001 From: Jaesung Chung Date: Thu, 18 Feb 2021 17:45:05 -0800 Subject: [PATCH] Add a new pass that freezes SavedModel's AssetOp This new pass will replace a func's saved model asset bound inputs which are bound to tf.InitializeTableFromTextFileV2Op ops with tf.Const ops inside the func's body. closes #46737 PiperOrigin-RevId: 358304291 Change-Id: I2bd2f6fbcafc30c878a7848eb4c107c3b48d9673 --- tensorflow/compiler/mlir/lite/BUILD | 1 + .../lite/python/graphdef_to_tfl_flatbuffer.cc | 4 +- .../python/saved_model_to_tfl_flatbuffer.cc | 2 +- .../lite/python/tf_tfl_flatbuffer_helpers.cc | 7 +- .../lite/python/tf_tfl_flatbuffer_helpers.h | 4 +- .../compiler/mlir/lite/tf_tfl_passes.cc | 17 ++- tensorflow/compiler/mlir/lite/tf_tfl_passes.h | 6 + tensorflow/compiler/mlir/tensorflow/BUILD | 1 + .../tests/tf_saved_model_freeze_assets.mlir | 86 ++++++++++++ .../transforms/freeze_saved_model_assets.cc | 124 ++++++++++++++++++ .../transforms/tf_saved_model_passes.h | 4 + 11 files changed, 247 insertions(+), 9 deletions(-) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_assets.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 06fa87fe35a..4941048dded 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -972,6 +972,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core:core_cpu_base", + "//tensorflow/lite/toco:model_flags_proto_cc", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 357948b1e77..995495ed20e 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -96,8 +96,8 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, } return internal::ConvertMLIRToTFLiteFlatBuffer( - toco_flags, std::move(module), pass_config, /*saved_model_tags=*/{}, - result, + model_flags, toco_flags, std::move(module), pass_config, + /*saved_model_tags=*/{}, result, /*session=*/llvm::None); } diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 4f3becaa1a0..a0d93b8d55c 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -184,7 +184,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, // TODO(b/153507667): Pass the session object when importing logic is removed. auto status = internal::ConvertMLIRToTFLiteFlatBuffer( - toco_flags, std::move(module), pass_config, tags, result, + model_flags, toco_flags, std::move(module), pass_config, tags, result, /*session=*/llvm::None); return status; } diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 213186f23c3..edde83c046d 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -288,8 +288,8 @@ Status DumpOpGraphToFile(mlir::ModuleOp module, const std::string& filename) { } Status ConvertMLIRToTFLiteFlatBuffer( - const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, - const mlir::TFL::PassConfig& pass_config, + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + mlir::OwningModuleRef module, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, string* result, llvm::Optional session) { bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops(); @@ -311,7 +311,8 @@ Status ConvertMLIRToTFLiteFlatBuffer( mlir::OpPassManager::Nesting::Implicit); ::tensorflow::SetCrashReproducer(pm); - tensorflow::AddTFToTFLConversionPasses(pass_config, &pm, session); + tensorflow::AddTFToTFLConversionPasses(model_flags, pass_config, &pm, + session); // Convert back to outlined while format for export back to flatbuffer. if (pass_config.legalize_tf_while) { pm.addPass(mlir::TFL::CreateWhileOutlinePass()); diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index abe1761fea0..7627012a1b0 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -48,8 +48,8 @@ Status PopulateQuantizationSpecs( // Convert imported MLIR file to TfLite flatbuffer. // This will also run relevant passes as well. Status ConvertMLIRToTFLiteFlatBuffer( - const toco::TocoFlags& toco_flags, mlir::OwningModuleRef module, - const mlir::TFL::PassConfig& pass_config, + const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags, + mlir::OwningModuleRef module, const mlir::TFL::PassConfig& pass_config, const std::unordered_set& saved_model_tags, string* result, llvm::Optional session); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 510f49ed41f..6f847f1c34d 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -62,7 +62,8 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); } -void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, +void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags, + const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager* pass_manager, llvm::Optional session) { mlir::TF::StandardPipelineOptions standard_pipeline_options; @@ -175,6 +176,13 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, /*allow_mutable_tensors=*/pass_config.enable_tflite_variables)); } + if (!model_flags.saved_model_dir().empty()) { + // This pass 'freezes' tf saved model asset ops and inlines as string values + // in a format of the tf constant op. + pass_manager->addPass(mlir::tf_saved_model::CreateFreezeAssetsPass( + model_flags.saved_model_dir())); + } + // The below passes only make sense if Builtin TFLite ops are enabled // for emission. if (pass_config.emit_builtin_tflite_ops) { @@ -253,6 +261,13 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, } } +void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager* pass_manager, + llvm::Optional session) { + const toco::ModelFlags model_flags; + AddTFToTFLConversionPasses(model_flags, pass_config, pass_manager, session); +} + } // namespace tensorflow namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h index 3a5027c3179..8104508a99f 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/core/public/session.h" +#include "tensorflow/lite/toco/model_flags.pb.h" namespace tensorflow { @@ -28,6 +29,11 @@ namespace tensorflow { // pass_manager. The session object will be provided when the TF MLIR is // imported from saved model version one and utilized for capturing resource // variables. +void AddTFToTFLConversionPasses(const toco::ModelFlags& model_flags, + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager* pass_manager, + llvm::Optional session); + void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager* pass_manager, llvm::Optional session); diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 028288a7143..dd785afb95f 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -788,6 +788,7 @@ cc_library( srcs = [ "transforms/deduplicate_bound_input_bindings.cc", "transforms/freeze_global_tensors.cc", + "transforms/freeze_saved_model_assets.cc", "transforms/lift_variables_pass.cc", "transforms/optimize_global_tensors.cc", "transforms/remove_vars_in_session_initializer.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_assets.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_assets.mlir new file mode 100644 index 00000000000..2daa223586d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_assets.mlir @@ -0,0 +1,86 @@ +// RUN: tf-opt -verify-diagnostics -tf-saved-model-freeze-assets -split-input-file %s | FileCheck %s + +module attributes {tf_saved_model.semantics} { + + // Test case: Basic freezing. + + "tf_saved_model.asset"() {filename = "assets/table.txt", sym_name = "v"} : () -> () + + // CHECK: func @f() + func @f(%arg0: tensor {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["f"]} { + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor, tensor) -> () + // CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<"assets/table.txt"> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string> + // CHECK: [[HASHTABLE:%.+]] = "tf.HashTableV2"() + // CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE]], [[CST]]) + return + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // Test case: Sanity check handling of non-bound inputs. + // The pass shouldn't do anything in this case. + + // CHECK: func @f(%arg0 + func @f(%arg0: tensor {tf_saved_model.index_path = [0]}) + attributes {tf_saved_model.exported_names = ["f"]} { + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor, tensor) -> () + // CHECK: "tf.InitializeTableFromTextFileV2"(%0, %arg0) + return + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // Test case: Sanity check handling of non tf.InitializeTableFromTextFileV2 op usages. + + "tf_saved_model.asset"() {filename = "assets/table.txt", sym_name = "v"} : () -> () + + // CHECK: func @f(%arg0 + func @f(%arg0: tensor {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["f"]} { + "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee} : (tensor) -> () + return + } + + func private @f_callee(%arg0: tensor) { + return + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + "tf_saved_model.asset"() {filename = "assets/table.txt", sym_name = "v"} : () -> () + "tf_saved_model.asset"() {filename = "assets/table2.txt", sym_name = "w"} : () -> () + + // CHECK: func @f() + func @f(%arg0: tensor {tf_saved_model.bound_input = @v}, %arg1: tensor {tf_saved_model.bound_input = @w}) + attributes {tf_saved_model.exported_names = ["f"]} { + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor, tensor) -> () + %1 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.InitializeTableFromTextFileV2"(%1, %arg1) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor, tensor) -> () + // CHECK: [[CST_1:%.+]] = "tf.Const"() {value = dense<"assets/table2.txt"> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string> + // CHECK: [[CST:%.+]] = "tf.Const"() {value = dense<"assets/table.txt"> : tensor<1x!tf.string>} : () -> tensor<1x!tf.string> + // CHECK: [[HASHTABLE:%.+]] = "tf.HashTableV2"() + // CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE]], [[CST]]) + // CHECK: [[HASHTABLE_1:%.+]] = "tf.HashTableV2"() + // CHECK: "tf.InitializeTableFromTextFileV2"([[HASHTABLE_1]], [[CST_1]]) + return + } +} + +// ----- + +// Test running the pass on a module that does not have +// tf_saved_model.semantics. +module {} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc new file mode 100644 index 00000000000..999cb7552ba --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc @@ -0,0 +1,124 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/core/platform/path.h" + +namespace mlir { +namespace tf_saved_model { +namespace { + +// This pass will replace a func's saved model asset bound inputs which are +// bound to tf.InitializeTableFromTextFileV2Op ops with tf.Const ops inside the +// func's body. +struct FreezeAssetsPass + : public PassWrapper> { + FreezeAssetsPass() = default; + + FreezeAssetsPass(const FreezeAssetsPass& pass) {} + explicit FreezeAssetsPass(std::string saved_model_dir) { + this->saved_model_dir = saved_model_dir; + } + + void runOnOperation() override; + + private: + std::string saved_model_dir; +}; + +void FreezeAssetsPass::runOnOperation() { + auto module = getOperation(); + if (!tf_saved_model::HasTfSavedModelSemantics(module)) { + return; + } + SymbolTable symbol_table(module); + + for (auto func : module.getOps()) { + SmallVector args_to_erase; + OpBuilder builder(func.getBody()); + + for (int i = 0, e = func.getNumArguments(); i < e; ++i) { + SmallVector + init_table_from_text_file_ops_to_erase; + auto asset = LookupBoundInputOfType(func, i, symbol_table); + + if (!asset) continue; + + auto arg = func.getArgument(i); + bool arg_is_deletable = true; + for (auto user : arg.getUsers()) { + if (auto read_op = + llvm::dyn_cast(user)) { + init_table_from_text_file_ops_to_erase.push_back(read_op); + } else { + arg_is_deletable = false; + continue; + } + } + if (arg_is_deletable) { + args_to_erase.push_back(i); + } + + // Replace the arg with a tf.Const op in the function body. + builder.setInsertionPointToStart(&func.getBody().front()); + + std::string asset_filename = asset.filename().str(); + std::string filename = + tensorflow::io::JoinPath(saved_model_dir, asset_filename); + ShapedType shaped_type = + RankedTensorType::get({1}, TF::StringType::get(builder.getContext())); + auto const_op = builder.create( + asset.getLoc(), + DenseStringElementsAttr::get(shaped_type, {filename})); + for (auto init_op : init_table_from_text_file_ops_to_erase) { + // Replace the InitializeTableFromTextFileV2Op to use the saved model's + // asset filepath. + builder.setInsertionPoint(init_op); + builder.create( + init_op.getLoc(), init_op.table_handle(), const_op.getResult(), + init_op.key_index(), init_op.value_index(), init_op.vocab_size(), + init_op.delimiter()); + init_op.erase(); + } + } + func.eraseArguments(args_to_erase); + } +} + +} // namespace + +// For "opt" to pick up this pass. +static PassRegistration freeze_assets_pass( + "tf-saved-model-freeze-assets", + "Freeze tf_saved_model.asset's in func bodies."); + +std::unique_ptr> CreateFreezeAssetsPass( + std::string saved_model_dir) { + return std::make_unique(saved_model_dir); +} + +} // namespace tf_saved_model +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h index 52a039355d8..e52f2c3d8e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h @@ -31,6 +31,10 @@ std::unique_ptr> CreateOptimizeGlobalTensorsPass(); std::unique_ptr> CreateFreezeGlobalTensorsPass( bool allow_mutable_tensors = false); +// Creates a pass that freezes tf_saved_model.asset ops. +std::unique_ptr> CreateFreezeAssetsPass( + std::string saved_model_dir); + // Creates as pass that removes variables in the session initializer. // This job is required with lifting variable passes. Originally, the session // initializer function does assigning variables. However, the read-only