Import initialization graph in SignatureDef SavedModels as an MLIR function in the TF saved model dialect.

PiperOrigin-RevId: 317137903
Change-Id: I7cbded06b3deafa30d3b3e3dad98cc8f056dd4e3
Kuangyuan Chen 2020-06-18 10:59:05 -07:00 committed by TensorFlower Gardener
parent cf00e559d7
commit 2ff1c5a31b
8 changed files with 290 additions and 43 deletions
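For orientation, after this change the importer emits a `tf_saved_model.session_initializer` op referencing the function converted from the SavedModel's initialization graph. A minimal sketch of the resulting IR, assuming an initializer function named `@init` (the name is illustrative; the importer derives it from the init node, as the hash-table test below shows):

module attributes {tf_saved_model.semantics} {
  // Marks @init as the function to run once to initialize session state.
  "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
  func @init() {
    // Typically table-population ops such as tf.LookupTableImportV2.
    return
  }
}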

View File

@@ -661,7 +661,9 @@ cc_library(
":tensorflow_types",
":translate_utils",
"//tensorflow/cc/saved_model:bundle_v2",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc/saved_model:loader_util",
"//tensorflow/compiler/jit:shape_inference_helpers",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
@@ -673,6 +675,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/utils:transitive_fanin",
"//tensorflow/core/platform:protobuf_internal",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
@@ -682,7 +685,6 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",

View File

@@ -76,6 +76,23 @@ static LogicalResult Verify(GlobalTensorOp global_tensor) {
return success();
}
static LogicalResult Verify(SessionInitializerOp session_initializer) {
mlir::SymbolTable symbol_table(
session_initializer.getParentOfType<ModuleOp>());
auto init_func_op =
symbol_table.lookup<mlir::FuncOp>(session_initializer.initializer());
if (!init_func_op)
return session_initializer.emitOpError()
<< "the initializer function does not exist";
if (!init_func_op.getType().getResults().empty())
return session_initializer.emitOpError()
<< "the initializer function should have no output";
return success();
}
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
@@ -220,6 +237,14 @@ static LogicalResult VerifySavedModelModule(
}
}
}
auto session_initializers = module.getOps<SessionInitializerOp>();
if (std::distance(session_initializers.begin(), session_initializers.end()) >
1) {
return (*++session_initializers.begin()).emitError()
<< "there must be no more than one session_initializer op";
}
SymbolTable symbol_table(module);
auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
if (!symbol_uses.hasValue()) {

View File

@@ -128,4 +128,28 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> {
let verifier = [{ return Verify(*this); }];
}
def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> {
let summary = "Initializes TensorFlow session state.";
let description = [{
Represents a session initializer function that initializes TensorFlow session
state. It is used to initialize resources in the saved model before calling
any exported functions. There must be no more than one session initializer
in a saved model.
The `initializer` represents the initialization function. The function must
have no output and should be called only once.
This is used, for example, to initialize hash tables stored in resources and
accessed by resource name (rather than as resource handles or bound inputs,
which is how `global_tensor`s are referenced).
}];
let arguments = (ins
FlatSymbolRefAttr:$initializer
);
let verifier = [{ return Verify(*this); }];
}
#endif // SAVED_MODEL_DIALECT
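Per the verifier above, the referenced function must exist in the module and must produce no results, and a module may contain at most one session_initializer op. A minimal well-formed instance (again with an illustrative `@init`):

"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() {
  return
}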

View File

@@ -84,6 +84,7 @@ def do_test(signature_def_map, show_debug_info=False):
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map,
main_op=tf.tables_initializer(),
strip_default_attrs=True)
builder.save()

View File

@@ -0,0 +1,92 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/hash_table_v1 | FileCheck %s
# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# Verify that the tf.versions attribute exists. It is difficult to enforce
# contents, since the version numbers change over time. The conversion logic
# itself is verified in the common graphdef converter, so here just assert
# it is being invoked.
# CHECK: module
# CHECK-SAME: tf.versions
# CHECK-SAME: bad_consumers
# CHECK-SAME: min_consumer
# CHECK-SAME: producer
# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
# CHECK: "tf_saved_model.global_tensor"()
# CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: [[ARG0:%.*]]: tensor<i32>
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
# CHECK-NEXT: [[R0:%.*]] = "tf.Const"()
# CHECK-NEXT: [[R1:%.*]] = "tf.HashTableV2"()
# CHECK-SAME: shared_name = "[[hash_table:.*]]"
# CHECK-NEXT: [[R2:%.*]] = "tf.LookupTableFindV2"([[R1]], [[ARG0]], [[R0]])
# CHECK-NEXT: [[R3:%.*]] = "tf.ReadVariableOp"([[ARG1]])
# CHECK-NEXT: [[R4:%.*]] = "tf.AddV2"([[R2]], [[R3]])
# CHECK-NEXT: return [[R4]]
# CHECK: func [[init]]
# CHECK-NEXT: [[R5:%.*]] = "tf.Const"()
# CHECK-NEXT: [[R6:%.*]] = "tf.Const"()
# CHECK-NEXT: [[R7:%.*]] = "tf.HashTableV2"()
# CHECK-SAME: shared_name = "[[hash_table]]"
# CHECK-NEXT: "tf.LookupTableImportV2"([[R7]], [[R5]], [[R6]])
def Test():
z = tf.compat.v1.get_variable(
name='y',
shape=(),
initializer=tf.random_normal_initializer(),
trainable=True)
table_initializer = tf.lookup.KeyValueTensorInitializer(
keys=[1, 2, 3, 4],
values=[5, 6, 7, 8],
key_dtype=tf.int32,
value_dtype=tf.float32)
table = tf.lookup.StaticHashTable(
table_initializer, default_value=tf.constant(0.0))
x = tf.placeholder(tf.int32, shape=(), name='input')
y = table.lookup(x)
r = tf.add(y, z)
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
return {
'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs={'x': tensor_info_x},
outputs={'r': tensor_info_r},
method_name='some_function'))
}
if __name__ == '__main__':
common_v1.set_tf_options()
common_v1.do_test(Test())

View File

@@ -2,6 +2,11 @@
module attributes {tf_saved_model.semantics} {
// CHECK: tf_saved_model.session_initializer
"tf_saved_model.session_initializer"() {
initializer = @f
} : () -> ()
// Representation for constants: (immutable) global tensor.
// CHECK: tf_saved_model.global_tensor
"tf_saved_model.global_tensor"() {

View File

@@ -258,3 +258,36 @@ module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{'type' attribute for immutable 'tf_saved_model.global_tensor' should have a static shape}}
"tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function does not exist}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function should have no output}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() -> tensor<1xf32> {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return %0 : tensor<1xf32>
}
}
// -----
module attributes {tf_saved_model.semantics} {
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
// expected-error@+1 {{there must be no more than one session_initializer op}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() -> tensor<1xf32> {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return %0 : tensor<1xf32>
}
}

View File

@@ -60,6 +60,8 @@ limitations under the License.
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
@@ -99,6 +101,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -116,6 +119,7 @@ using mlir::NamedAttrList;
using mlir::TensorType;
using mlir::TF::VarHandleOp;
using mlir::tf_saved_model::GlobalTensorOp;
using mlir::tf_saved_model::SessionInitializerOp;
using stream_executor::port::StatusOr;
namespace {
@@ -2955,6 +2959,13 @@ void SortSavedModelModule(mlir::ModuleOp module) {
named_global_tensor.global_tensor.getOperation()->moveBefore(
&module.getBody()->front());
}
auto initializers = module.getOps<SessionInitializerOp>();
if (!initializers.empty()) {
(*initializers.begin())
.getOperation()
->moveBefore(&module.getBody()->front());
}
}
Status CreateSavedModelIR(
@@ -3241,17 +3252,29 @@ class SavedModelSignatureDefImporter {
absl::Span<std::string> exported_names,
mlir::MLIRContext* context)
: bundle_(bundle),
flib_def_(OpRegistry::Global(), graph_def().library()),
debug_info_(),
exported_names_(exported_names),
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {
// debug_info might not be loaded with loader_lite.
if (bundle_.debug_info != nullptr) debug_info_ = *bundle_.debug_info;
}
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
// for each signature.
StatusOr<mlir::OwningModuleRef> ConvertSignatures();
Status ConvertSignature(const GraphDef& graphdef,
const std::string& sig_def_key,
const SignatureDef& signature_def,
const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def);
Status ConvertSignature(const std::string& sig_def_key,
const SignatureDef& signature_def);
// Converts the initialization graph in the SavedModel to an MLIR function.
Status ConvertInitializer();
// Converts a graph with feeds and fetches to an MLIR function.
StatusOr<mlir::OwningModuleRef> ConvertGraph(
const std::string& name,
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
const std::vector<std::string> control_outputs);
// Creates GlobalTensorOp for each variable and moves each VarHandle op to
// the enclosing function's arguments.
@@ -3273,18 +3296,62 @@ class SavedModelSignatureDefImporter {
GraphImportConfig::InputArrays ParseInputArrays(
const std::vector<std::pair<std::string, TensorInfo>>& inputs);
const GraphDef& graph_def() const {
return bundle_.meta_graph_def.graph_def();
}
const FunctionLibraryDefinition& flib_def() const { return flib_def_; }
const GraphDebugInfo& debug_info() const { return debug_info_; }
const SavedModelBundle& bundle_;
FunctionLibraryDefinition flib_def_;
GraphDebugInfo debug_info_;
absl::Span<std::string> exported_names_;
mlir::OwningModuleRef module_;
};
Status SavedModelSignatureDefImporter::ConvertInitializer() {
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(
internal::GetAssetFileDefs(bundle_.meta_graph_def, &asset_file_defs));
if (!asset_file_defs.empty())
return errors::Unimplemented(
absl::StrCat("Assets are not supported in signaturedef importer"));
std::string init_node_name;
TF_RETURN_IF_ERROR(
internal::GetInitOp("", bundle_.meta_graph_def, &init_node_name));
if (init_node_name.empty()) return Status::OK();
TF_ASSIGN_OR_RETURN(auto sub_module,
ConvertGraph(init_node_name, {}, {}, {init_node_name}));
mlir::SymbolTable symbol_table(*sub_module);
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(init_node_name);
init_func_op.removeAttr("tf.entry_function");
mlir::OpBuilder builder(module_->getBodyRegion());
builder.create<mlir::tf_saved_model::SessionInitializerOp>(
module_->getLoc(), builder.getSymbolRefAttr(init_func_op.getName()));
// Move the converted functions to top level MLIR module.
auto* block = module_->getBody();
auto* sub_block = sub_module->getBody();
block->getOperations().splice(
mlir::Block::iterator(block->getTerminator()), sub_block->getOperations(),
sub_block->begin(), mlir::Block::iterator(sub_block->getTerminator()));
return Status::OK();
}
StatusOr<mlir::OwningModuleRef>
SavedModelSignatureDefImporter::ConvertSignatures() {
const auto& signatures = bundle_.GetSignatures();
const auto& graphdef = bundle_.meta_graph_def.graph_def();
PopulateTfVersions(module_.get(), graphdef.versions());
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graphdef.library());
PopulateTfVersions(module_.get(), graph_def().versions());
// debug_info might not be loaded with loader_lite.
GraphDebugInfo debug_info;
@@ -3307,9 +3374,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
continue;
}
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
debug_info, flib_def));
TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
}
TF_RETURN_IF_ERROR(ConvertInitializer());
TF_RETURN_IF_ERROR(LiftVariables());
mlir::OpBuilder builder(module_->getBodyRegion());
@@ -3320,10 +3388,32 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
return std::move(module_);
}
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefImporter::ConvertGraph(
const std::string& name,
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
const std::vector<std::string> control_outputs) {
GraphImportConfig specs;
specs.prune_unused_nodes = true;
specs.inputs = ParseInputArrays(inputs);
for (auto& output : outputs) specs.outputs.push_back(output.second.name());
specs.control_outputs = control_outputs;
// Convert sub-graphdef to sub-graph.
GraphConstructorOptions options;
options.allow_internal_ops = true;
options.add_default_attributes = true;
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(options, graph_def(), &graph));
// Convert sub-graph to MLIR module.
return GraphDefImporter::Convert(module_->getContext(), graph, debug_info(),
flib_def(), specs, name);
}
Status SavedModelSignatureDefImporter::ConvertSignature(
const GraphDef& graphdef, const std::string& sig_def_key,
const SignatureDef& signature_def, const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def) {
const std::string& sig_def_key, const SignatureDef& signature_def) {
// Create local vectors for the input and output and sort them to be
// deterministic. We don't want anyone to really depend on the order, client
// should lookup argument/result mapping by attribute name.
@@ -3339,34 +3429,9 @@ Status SavedModelSignatureDefImporter::ConvertSignature(
return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
});
GraphImportConfig specs;
specs.prune_unused_nodes = true;
specs.inputs = ParseInputArrays(inputs);
for (auto& output : outputs) specs.outputs.push_back(output.second.name());
// Remove unused nodes and create sub-graphdef.
GraphDef sub_graph_def;
TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
graphdef, &sub_graph_def,
/*terminal_nodes=*/{specs.outputs.begin(), specs.outputs.end()}));
// Set the function library definitions in the pruned graphdef.
*sub_graph_def.mutable_library() = flib_def.ToProto();
// Convert sub-graphdef to sub-graph.
GraphConstructorOptions options;
options.allow_internal_ops = true;
options.add_default_attributes = true;
Graph sub_graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(options, sub_graph_def, &sub_graph));
// Convert sub-graph to MLIR module.
TF_ASSIGN_OR_RETURN(
auto sub_module,
GraphDefImporter::Convert(module_->getContext(), sub_graph, debug_info,
flib_def, specs, sig_def_key));
TF_ASSIGN_OR_RETURN(auto sub_module,
ConvertGraph(sig_def_key, inputs, outputs, {}));
mlir::OpBuilder builder(sub_module->getBodyRegion());
// Find the FuncOp which corresponds to current SignatureDef.