diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index 904ccb7e820..17ed0e36a28 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -661,7 +661,9 @@ cc_library(
         ":tensorflow_types",
         ":translate_utils",
         "//tensorflow/cc/saved_model:bundle_v2",
+        "//tensorflow/cc/saved_model:constants",
         "//tensorflow/cc/saved_model:loader_lite",
+        "//tensorflow/cc/saved_model:loader_util",
         "//tensorflow/compiler/jit:shape_inference_helpers",
         "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
         "//tensorflow/compiler/tf2xla:functionalize_control_flow",
@@ -673,6 +675,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler/utils:transitive_fanin",
+        "//tensorflow/core/platform:protobuf_internal",
         "//tensorflow/core/platform:types",
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/algorithm:container",
@@ -682,7 +685,6 @@ cc_library(
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:StandardOps",
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
index 140a778770c..6af70158e14 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc
@@ -76,6 +76,23 @@ static LogicalResult Verify(GlobalTensorOp global_tensor) {
   return success();
 }
 
+static LogicalResult Verify(SessionInitializerOp session_initializer) {
+  mlir::SymbolTable symbol_table(
+      session_initializer.getParentOfType<ModuleOp>());
+
+  auto init_func_op =
+      symbol_table.lookup<mlir::FuncOp>(session_initializer.initializer());
+  if (!init_func_op)
+    return session_initializer.emitOpError()
+           << "the initializer function does not exist";
+
+  if (!init_func_op.getType().getResults().empty())
+    return session_initializer.emitOpError()
+           << "the initializer function should have no output";
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
 
@@ -220,6 +237,14 @@ static LogicalResult VerifySavedModelModule(
       }
     }
   }
+
+  auto session_initializers = module.getOps<SessionInitializerOp>();
+  if (std::distance(session_initializers.begin(), session_initializers.end()) >
+      1) {
+    return (*++session_initializers.begin()).emitError()
+           << "there must be no more than one session_initializer op";
+  }
+
   SymbolTable symbol_table(module);
   auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
   if (!symbol_uses.hasValue()) {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td
index 4431a160edf..497f4d90cb9 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td
@@ -128,4 +128,28 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> {
   let verifier = [{ return Verify(*this); }];
 }
 
+def TfSavedModel_SessionInitializerOp : TfSavedModel_Op<"session_initializer"> {
+  let summary = "Initializes TensorFlow session state.";
+  let description = [{
+    Represents a session initializer function that initializes TensorFlow
+    session state. It is used to initialize resources in the saved model
+    before calling any exported functions. There must be no more than one
+    session initializer in a saved model.
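+
+    For example, a session initializer op referencing its init function could
+    look like the following (the function name `@init_all_tables` is
+    illustrative only):
+
+      "tf_saved_model.session_initializer"() {
+        initializer = @init_all_tables
+      } : () -> ()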
+
+    The `initializer` attribute references the initialization function. That
+    function should have no output and should only be called once.
+
+    This is used, for example, to initialize hash tables stored in resources
+    and accessed by resource name (rather than as resource handles or bound
+    inputs, which is how `global_tensor`s are referenced).
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$initializer
+  );
+
+  let verifier = [{ return Verify(*this); }];
+}
+
 #endif // SAVED_MODEL_DIALECT
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
index 7171f63bb05..51ccbeb1fbd 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py
@@ -84,6 +84,7 @@ def do_test(signature_def_map, show_debug_info=False):
   builder.add_meta_graph_and_variables(
       sess, [tf.saved_model.tag_constants.SERVING],
       signature_def_map,
+      main_op=tf.tables_initializer(),
       strip_default_attrs=True)
   builder.save()
 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
new file mode 100644
index 00000000000..64847434b82
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py
@@ -0,0 +1,92 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# RUN: %p/hash_table_v1 | FileCheck %s
+
+# pylint: disable=missing-docstring,line-too-long
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v1 as tf
+from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
+
+# Verify that the tf.versions attribute exists. It is difficult to enforce
+# contents, since the version numbers change over time. The conversion logic
+# itself is verified in the common graphdef converter, so here just assert
+# it is being invoked.
+# CHECK: module
+# CHECK-SAME: tf.versions
+# CHECK-SAME: bad_consumers
+# CHECK-SAME: min_consumer
+# CHECK-SAME: producer

+# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
+# CHECK: "tf_saved_model.global_tensor"()
+
+# CHECK: func {{@[a-zA-Z_0-9]+}}(
+# CHECK-SAME:   [[ARG0:%.*]]: tensor
+# CHECK-SAME:   [[ARG1:%.*]]: tensor
   // Representation for constants: (immutable) global tensor.
   // CHECK: tf_saved_model.global_tensor
   "tf_saved_model.global_tensor"() {
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir
index c055c6c9f56..544600cf6b8 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir
@@ -258,3 +258,36 @@ module attributes {tf_saved_model.semantics} {
   // expected-error@+1 {{'type' attribute for immutable 'tf_saved_model.global_tensor' should have a static shape}}
   "tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
 }
+
+// -----
+
+module attributes {tf_saved_model.semantics} {
+
+  // expected-error@+1 {{the initializer function does not exist}}
+  "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
+}
+
+// -----
+
+module attributes {tf_saved_model.semantics} {
+
+  // expected-error@+1 {{the initializer function should have no output}}
+  "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
+  func @init() -> tensor<1xf32> {
+    %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
+    return %0 : tensor<1xf32>
+  }
+}
+
+// -----
+
+module attributes {tf_saved_model.semantics} {
+
+  "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
+  // expected-error@+1 {{there must be no more than one session_initializer op}}
+  "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
+  func @init() -> tensor<1xf32> {
+    %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
+    return %0 : tensor<1xf32>
+  }
+}
diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
index 820d0ce31fb..3cff4217215 100644
--- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
+++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc
@@ -60,6 +60,8 @@ limitations under the License.
 #include "mlir/IR/Types.h"  // from @llvm-project
 #include "mlir/IR/Verifier.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "tensorflow/cc/saved_model/constants.h"
+#include "tensorflow/cc/saved_model/loader_util.h"
 #include "tensorflow/compiler/jit/shape_inference_helpers.h"
 #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
@@ -99,6 +101,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -116,6 +119,7 @@ using mlir::NamedAttrList; using mlir::TensorType; using mlir::TF::VarHandleOp; using mlir::tf_saved_model::GlobalTensorOp; +using mlir::tf_saved_model::SessionInitializerOp; using stream_executor::port::StatusOr; namespace { @@ -2955,6 +2959,13 @@ void SortSavedModelModule(mlir::ModuleOp module) { named_global_tensor.global_tensor.getOperation()->moveBefore( &module.getBody()->front()); } + + auto initializers = module.getOps(); + if (!initializers.empty()) { + (*initializers.begin()) + .getOperation() + ->moveBefore(&module.getBody()->front()); + } } Status CreateSavedModelIR( @@ -3241,17 +3252,29 @@ class SavedModelSignatureDefImporter { absl::Span exported_names, mlir::MLIRContext* context) : bundle_(bundle), + flib_def_(OpRegistry::Global(), graph_def().library()), + debug_info_(), exported_names_(exported_names), - module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {} + module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) { + // debug_info might not be loaded with loader_lite. + if (bundle_.debug_info != nullptr) debug_info_ = *bundle_.debug_info; + } // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function // for each signature. StatusOr ConvertSignatures(); - Status ConvertSignature(const GraphDef& graphdef, - const std::string& sig_def_key, - const SignatureDef& signature_def, - const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def); + Status ConvertSignature(const std::string& sig_def_key, + const SignatureDef& signature_def); + + // Converts the initialization graph in the SavedModel to an MLIR function. + Status ConvertInitializer(); + + // Converts a graph with feeds and fetches to an MLIR function. + StatusOr ConvertGraph( + const std::string& name, + const std::vector>& inputs, + const std::vector>& outputs, + const std::vector control_outputs); // Creates GlobalTensorOp for each variable and moves each VarHandle op to // the enclosing function's arguments. 
@@ -3273,18 +3296,62 @@ class SavedModelSignatureDefImporter {
   GraphImportConfig::InputArrays ParseInputArrays(
       const std::vector<std::pair<std::string, TensorInfo>>& inputs);
 
+  const GraphDef& graph_def() const {
+    return bundle_.meta_graph_def.graph_def();
+  }
+  const FunctionLibraryDefinition& flib_def() const { return flib_def_; }
+  const GraphDebugInfo& debug_info() const { return debug_info_; }
+
   const SavedModelBundle& bundle_;
+  FunctionLibraryDefinition flib_def_;
+  GraphDebugInfo debug_info_;
   absl::Span<std::string> exported_names_;
   mlir::OwningModuleRef module_;
 };
 
+Status SavedModelSignatureDefImporter::ConvertInitializer() {
+  std::vector<AssetFileDef> asset_file_defs;
+  TF_RETURN_IF_ERROR(
+      internal::GetAssetFileDefs(bundle_.meta_graph_def, &asset_file_defs));
+
+  if (!asset_file_defs.empty())
+    return errors::Unimplemented(
+        absl::StrCat("Assets are not supported in signaturedef importer"));
+
+  std::string init_node_name;
+  TF_RETURN_IF_ERROR(
+      internal::GetInitOp("", bundle_.meta_graph_def, &init_node_name));
+
+  if (init_node_name.empty()) return Status::OK();
+
+  TF_ASSIGN_OR_RETURN(auto sub_module,
+                      ConvertGraph(init_node_name, {}, {}, {init_node_name}));
+
+  mlir::SymbolTable symbol_table(*sub_module);
+
+  auto init_func_op = symbol_table.lookup<mlir::FuncOp>(init_node_name);
+
+  init_func_op.removeAttr("tf.entry_function");
+
+  mlir::OpBuilder builder(module_->getBodyRegion());
+
+  builder.create<SessionInitializerOp>(
+      module_->getLoc(), builder.getSymbolRefAttr(init_func_op.getName()));
+
+  // Move the converted functions to top level MLIR module.
+  auto* block = module_->getBody();
+  auto* sub_block = sub_module->getBody();
+  block->getOperations().splice(
+      mlir::Block::iterator(block->getTerminator()), sub_block->getOperations(),
+      sub_block->begin(), mlir::Block::iterator(sub_block->getTerminator()));
+
+  return Status::OK();
+}
+
 StatusOr<mlir::OwningModuleRef>
 SavedModelSignatureDefImporter::ConvertSignatures() {
   const auto& signatures = bundle_.GetSignatures();
-  const auto& graphdef = bundle_.meta_graph_def.graph_def();
-  PopulateTfVersions(module_.get(), graphdef.versions());
-
-  FunctionLibraryDefinition flib_def(OpRegistry::Global(), graphdef.library());
+  PopulateTfVersions(module_.get(), graph_def().versions());
 
   // debug_info might not be loaded with loader_lite.
   GraphDebugInfo debug_info;
@@ -3307,9 +3374,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
       continue;
     }
 
-    TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
-                                        debug_info, flib_def));
+    TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
   }
+
+  TF_RETURN_IF_ERROR(ConvertInitializer());
   TF_RETURN_IF_ERROR(LiftVariables());
 
   mlir::OpBuilder builder(module_->getBodyRegion());
@@ -3320,10 +3388,32 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
 
   return std::move(module_);
 }
 
+StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefImporter::ConvertGraph(
+    const std::string& name,
+    const std::vector<std::pair<std::string, TensorInfo>>& inputs,
+    const std::vector<std::pair<std::string, TensorInfo>>& outputs,
+    const std::vector<std::string> control_outputs) {
+  GraphImportConfig specs;
+  specs.prune_unused_nodes = true;
+  specs.inputs = ParseInputArrays(inputs);
+  for (auto& output : outputs) specs.outputs.push_back(output.second.name());
+  specs.control_outputs = control_outputs;
+
+  // Convert the graphdef to a graph.
+  GraphConstructorOptions options;
+  options.allow_internal_ops = true;
+  options.add_default_attributes = true;
+  Graph graph(OpRegistry::Global());
+
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(options, graph_def(), &graph));
+
+  // Convert the graph to an MLIR module.
+  return GraphDefImporter::Convert(module_->getContext(), graph, debug_info(),
+                                   flib_def(), specs, name);
+}
+
 Status SavedModelSignatureDefImporter::ConvertSignature(
-    const GraphDef& graphdef, const std::string& sig_def_key,
-    const SignatureDef& signature_def, const GraphDebugInfo& debug_info,
-    const FunctionLibraryDefinition& flib_def) {
+    const std::string& sig_def_key, const SignatureDef& signature_def) {
   // Create local vectors for the input and output and sort them to be
   // deterministic. We don't want anyone to really depend on the order, client
   // should lookup argument/result mapping by attribute name.
@@ -3339,34 +3429,9 @@ Status SavedModelSignatureDefImporter::ConvertSignature(
     return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
   });
 
-  GraphImportConfig specs;
-  specs.prune_unused_nodes = true;
-  specs.inputs = ParseInputArrays(inputs);
-  for (auto& output : outputs) specs.outputs.push_back(output.second.name());
-
-  // Remove unused nodes and create sub-graphdef.
-  GraphDef sub_graph_def;
-  TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
-      graphdef, &sub_graph_def,
-      /*terminal_nodes=*/{specs.outputs.begin(), specs.outputs.end()}));
-
-  // Set the function library definitions in the pruned graphdef.
-  *sub_graph_def.mutable_library() = flib_def.ToProto();
-
-  // Convert sub-graphdef to sub-graph.
-  GraphConstructorOptions options;
-  options.allow_internal_ops = true;
-  options.add_default_attributes = true;
-  Graph sub_graph(OpRegistry::Global());
-
-  TF_RETURN_IF_ERROR(
-      ConvertGraphDefToGraph(options, sub_graph_def, &sub_graph));
-  // Convert sub-graph to MLIR module.
-  TF_ASSIGN_OR_RETURN(
-      auto sub_module,
-      GraphDefImporter::Convert(module_->getContext(), sub_graph, debug_info,
-                                flib_def, specs, sig_def_key));
+  TF_ASSIGN_OR_RETURN(auto sub_module,
+                      ConvertGraph(sig_def_key, inputs, outputs, {}));
 
   mlir::OpBuilder builder(sub_module->getBodyRegion());
 
   // Find the FuncOp which corresponds to current SignatureDef.
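
For reference, a module produced by the SignatureDef importer after this change is expected to contain a single session_initializer op plus the spliced-in init function, roughly as in the sketch below. The function name `@init_all_tables` and the body are illustrative only; the real names and table-initialization ops depend on the SavedModel being imported.

  module attributes {tf_saved_model.semantics} {
    "tf_saved_model.session_initializer"() { initializer = @init_all_tables } : () -> ()
    func @init_all_tables() {
      // Table-initialization ops (e.g. hash-table imports) would appear here.
      return
    }
  }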