Import initialization graph in SignatureDef SavedModels as an MLIR function in
TF saved model dialect. PiperOrigin-RevId: 317137903 Change-Id: I7cbded06b3deafa30d3b3e3dad98cc8f056dd4e3
This commit is contained in:
parent
cf00e559d7
commit
2ff1c5a31b
|
@ -661,7 +661,9 @@ cc_library(
|
||||||
":tensorflow_types",
|
":tensorflow_types",
|
||||||
":translate_utils",
|
":translate_utils",
|
||||||
"//tensorflow/cc/saved_model:bundle_v2",
|
"//tensorflow/cc/saved_model:bundle_v2",
|
||||||
|
"//tensorflow/cc/saved_model:constants",
|
||||||
"//tensorflow/cc/saved_model:loader_lite",
|
"//tensorflow/cc/saved_model:loader_lite",
|
||||||
|
"//tensorflow/cc/saved_model:loader_util",
|
||||||
"//tensorflow/compiler/jit:shape_inference_helpers",
|
"//tensorflow/compiler/jit:shape_inference_helpers",
|
||||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||||
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
|
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
|
||||||
|
@ -673,6 +675,7 @@ cc_library(
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler/utils:transitive_fanin",
|
"//tensorflow/core/grappler/utils:transitive_fanin",
|
||||||
|
"//tensorflow/core/platform:protobuf_internal",
|
||||||
"//tensorflow/core/platform:types",
|
"//tensorflow/core/platform:types",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
@ -682,7 +685,6 @@ cc_library(
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:Analysis",
|
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
|
|
@ -76,6 +76,23 @@ static LogicalResult Verify(GlobalTensorOp global_tensor) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static LogicalResult Verify(SessionInitializerOp session_initializer) {
|
||||||
|
mlir::SymbolTable symbol_table(
|
||||||
|
session_initializer.getParentOfType<ModuleOp>());
|
||||||
|
|
||||||
|
auto init_func_op =
|
||||||
|
symbol_table.lookup<mlir::FuncOp>(session_initializer.initializer());
|
||||||
|
if (!init_func_op)
|
||||||
|
return session_initializer.emitOpError()
|
||||||
|
<< "the initializer function does not exist";
|
||||||
|
|
||||||
|
if (!init_func_op.getType().getResults().empty())
|
||||||
|
return session_initializer.emitOpError()
|
||||||
|
<< "the initializer function should have no output";
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
|
||||||
|
|
||||||
|
@ -220,6 +237,14 @@ static LogicalResult VerifySavedModelModule(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto session_initializers = module.getOps<SessionInitializerOp>();
|
||||||
|
if (std::distance(session_initializers.begin(), session_initializers.end()) >
|
||||||
|
1) {
|
||||||
|
return (*++session_initializers.begin()).emitError()
|
||||||
|
<< "there must be no more than one session_initializer op";
|
||||||
|
}
|
||||||
|
|
||||||
SymbolTable symbol_table(module);
|
SymbolTable symbol_table(module);
|
||||||
auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
|
auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
|
||||||
if (!symbol_uses.hasValue()) {
|
if (!symbol_uses.hasValue()) {
|
||||||
|
|
|
@ -128,4 +128,28 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> {
|
||||||
let verifier = [{ return Verify(*this); }];
|
let verifier = [{ return Verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> {
|
||||||
|
let summary = "Initializes TensorFlow session state.";
|
||||||
|
let description = [{
|
||||||
|
Represents a session initializer function initializes TensorFlow session
|
||||||
|
state. It is used to initialize resources in the saved model before calling
|
||||||
|
any exported functions. There must be no more than one session initializer
|
||||||
|
in a saved model.
|
||||||
|
|
||||||
|
The `initializer` represents the initialization function. The function have
|
||||||
|
no output and this function should be only called once.
|
||||||
|
|
||||||
|
This is used, for example, to initialize hash tables stored in resources and
|
||||||
|
accessed by resource name (rather than as resource handles or bound inputs
|
||||||
|
which is how `global_tensor`s are referenced)
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
FlatSymbolRefAttr:$initializer
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
}
|
||||||
|
|
||||||
#endif // SAVED_MODEL_DIALECT
|
#endif // SAVED_MODEL_DIALECT
|
||||||
|
|
|
@ -84,6 +84,7 @@ def do_test(signature_def_map, show_debug_info=False):
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, [tf.saved_model.tag_constants.SERVING],
|
sess, [tf.saved_model.tag_constants.SERVING],
|
||||||
signature_def_map,
|
signature_def_map,
|
||||||
|
main_op=tf.tables_initializer(),
|
||||||
strip_default_attrs=True)
|
strip_default_attrs=True)
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,92 @@
|
||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
# RUN: %p/hash_table_v1 | FileCheck %s
|
||||||
|
|
||||||
|
# pylint: disable=missing-docstring,line-too-long
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||||
|
|
||||||
|
# Verify that the tf.versions attribute exists. It is difficult to enforce
|
||||||
|
# contents, since the version numbers change over time. The conversion logic
|
||||||
|
# itself is verified in the common graphdef converter, so here just assert
|
||||||
|
# it is being invoked.
|
||||||
|
# CHECK: module
|
||||||
|
# CHECK-SAME: tf.versions
|
||||||
|
# CHECK-SAME: bad_consumers
|
||||||
|
# CHECK-SAME: min_consumer
|
||||||
|
# CHECK-SAME: producer
|
||||||
|
|
||||||
|
# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
|
||||||
|
# CHECK: "tf_saved_model.global_tensor"()
|
||||||
|
|
||||||
|
# CHECK: func {{@[a-zA-Z_0-9]+}}(
|
||||||
|
# CHECK-SAME: [[ARG0:%.*]]: tensor<i32>
|
||||||
|
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource
|
||||||
|
# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"]
|
||||||
|
|
||||||
|
# CHECK-NEXT: [[R0:%.*]] = "tf.Const"()
|
||||||
|
# CHECK-NEXT: [[R1:%.*]] = "tf.HashTableV2"()
|
||||||
|
# CHECK-SAME: shared_name = "[[hash_table:.*]]"
|
||||||
|
# CHECK-NEXT: [[R2:%.*]] = "tf.LookupTableFindV2"([[R1]], [[ARG0]], [[R0]])
|
||||||
|
# CHECK-NEXT: [[R3:%.*]] = "tf.ReadVariableOp"([[ARG1]])
|
||||||
|
# CHECK-NEXT: [[R4:%.*]] = "tf.AddV2"([[R2]], [[R3]])
|
||||||
|
# CHECK-NEXT: return [[R4]]
|
||||||
|
|
||||||
|
# CHECK: func [[init]]
|
||||||
|
# CHECK-NEXT: [[R5:%.*]] = "tf.Const"()
|
||||||
|
# CHECK-NEXT: [[R6:%.*]] = "tf.Const"()
|
||||||
|
# CHECK-NEXT: [[R7:%.*]] = "tf.HashTableV2"()
|
||||||
|
# CHECK-SAME: shared_name = "[[hash_table]]"
|
||||||
|
# CHECK-NEXT: "tf.LookupTableImportV2"([[R7]], [[R5]], [[R6]])
|
||||||
|
|
||||||
|
|
||||||
|
def Test():
|
||||||
|
|
||||||
|
z = tf.compat.v1.get_variable(
|
||||||
|
name='y',
|
||||||
|
shape=(),
|
||||||
|
initializer=tf.random_normal_initializer(),
|
||||||
|
trainable=True)
|
||||||
|
table_initializer = tf.lookup.KeyValueTensorInitializer(
|
||||||
|
keys=[1, 2, 3, 4],
|
||||||
|
values=[5, 6, 7, 8],
|
||||||
|
key_dtype=tf.int32,
|
||||||
|
value_dtype=tf.float32)
|
||||||
|
table = tf.lookup.StaticHashTable(
|
||||||
|
table_initializer, default_value=tf.constant(0.0))
|
||||||
|
|
||||||
|
x = tf.placeholder(tf.int32, shape=(), name='input')
|
||||||
|
y = table.lookup(x)
|
||||||
|
r = tf.add(y, z)
|
||||||
|
|
||||||
|
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
|
||||||
|
tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
|
||||||
|
inputs={'x': tensor_info_x},
|
||||||
|
outputs={'r': tensor_info_r},
|
||||||
|
method_name='some_function'))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
common_v1.set_tf_options()
|
||||||
|
common_v1.do_test(Test())
|
|
@ -2,6 +2,11 @@
|
||||||
|
|
||||||
module attributes {tf_saved_model.semantics} {
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
|
// CHECK: tf_saved_model.session_initializer
|
||||||
|
"tf_saved_model.session_initializer"() {
|
||||||
|
initializer = @f
|
||||||
|
} : () -> ()
|
||||||
|
|
||||||
// Representation for constants: (immutable) global tensor.
|
// Representation for constants: (immutable) global tensor.
|
||||||
// CHECK: tf_saved_model.global_tensor
|
// CHECK: tf_saved_model.global_tensor
|
||||||
"tf_saved_model.global_tensor"() {
|
"tf_saved_model.global_tensor"() {
|
||||||
|
|
|
@ -258,3 +258,36 @@ module attributes {tf_saved_model.semantics} {
|
||||||
// expected-error@+1 {{'type' attribute for immutable 'tf_saved_model.global_tensor' should have a static shape}}
|
// expected-error@+1 {{'type' attribute for immutable 'tf_saved_model.global_tensor' should have a static shape}}
|
||||||
"tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
"tf_saved_model.global_tensor"() { sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
|
// expected-error@+1 {{the initializer function does not exist}}
|
||||||
|
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
|
// expected-error@+1 {{the initializer function should have no output}}
|
||||||
|
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||||
|
func @init() -> tensor<1xf32> {
|
||||||
|
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||||
|
return %0 : tensor<1xf32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module attributes {tf_saved_model.semantics} {
|
||||||
|
|
||||||
|
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||||
|
// expected-error@+1 {{there must be no more than one session_initializer op}}
|
||||||
|
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
|
||||||
|
func @init() -> tensor<1xf32> {
|
||||||
|
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
|
||||||
|
return %0 : tensor<1xf32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -60,6 +60,8 @@ limitations under the License.
|
||||||
#include "mlir/IR/Types.h" // from @llvm-project
|
#include "mlir/IR/Types.h" // from @llvm-project
|
||||||
#include "mlir/IR/Verifier.h" // from @llvm-project
|
#include "mlir/IR/Verifier.h" // from @llvm-project
|
||||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||||
|
#include "tensorflow/cc/saved_model/constants.h"
|
||||||
|
#include "tensorflow/cc/saved_model/loader_util.h"
|
||||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||||
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
|
||||||
|
@ -99,6 +101,7 @@ limitations under the License.
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
#include "tensorflow/core/platform/protobuf_internal.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||||
|
@ -116,6 +119,7 @@ using mlir::NamedAttrList;
|
||||||
using mlir::TensorType;
|
using mlir::TensorType;
|
||||||
using mlir::TF::VarHandleOp;
|
using mlir::TF::VarHandleOp;
|
||||||
using mlir::tf_saved_model::GlobalTensorOp;
|
using mlir::tf_saved_model::GlobalTensorOp;
|
||||||
|
using mlir::tf_saved_model::SessionInitializerOp;
|
||||||
using stream_executor::port::StatusOr;
|
using stream_executor::port::StatusOr;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -2955,6 +2959,13 @@ void SortSavedModelModule(mlir::ModuleOp module) {
|
||||||
named_global_tensor.global_tensor.getOperation()->moveBefore(
|
named_global_tensor.global_tensor.getOperation()->moveBefore(
|
||||||
&module.getBody()->front());
|
&module.getBody()->front());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto initializers = module.getOps<SessionInitializerOp>();
|
||||||
|
if (!initializers.empty()) {
|
||||||
|
(*initializers.begin())
|
||||||
|
.getOperation()
|
||||||
|
->moveBefore(&module.getBody()->front());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CreateSavedModelIR(
|
Status CreateSavedModelIR(
|
||||||
|
@ -3241,17 +3252,29 @@ class SavedModelSignatureDefImporter {
|
||||||
absl::Span<std::string> exported_names,
|
absl::Span<std::string> exported_names,
|
||||||
mlir::MLIRContext* context)
|
mlir::MLIRContext* context)
|
||||||
: bundle_(bundle),
|
: bundle_(bundle),
|
||||||
|
flib_def_(OpRegistry::Global(), graph_def().library()),
|
||||||
|
debug_info_(),
|
||||||
exported_names_(exported_names),
|
exported_names_(exported_names),
|
||||||
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
|
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {
|
||||||
|
// debug_info might not be loaded with loader_lite.
|
||||||
|
if (bundle_.debug_info != nullptr) debug_info_ = *bundle_.debug_info;
|
||||||
|
}
|
||||||
|
|
||||||
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
|
// Converts the SavedModel to the SavedModel dialect. Creates an MLIR function
|
||||||
// for each signature.
|
// for each signature.
|
||||||
StatusOr<mlir::OwningModuleRef> ConvertSignatures();
|
StatusOr<mlir::OwningModuleRef> ConvertSignatures();
|
||||||
Status ConvertSignature(const GraphDef& graphdef,
|
Status ConvertSignature(const std::string& sig_def_key,
|
||||||
const std::string& sig_def_key,
|
const SignatureDef& signature_def);
|
||||||
const SignatureDef& signature_def,
|
|
||||||
const GraphDebugInfo& debug_info,
|
// Converts the initialization graph in the SavedModel to an MLIR function.
|
||||||
const FunctionLibraryDefinition& flib_def);
|
Status ConvertInitializer();
|
||||||
|
|
||||||
|
// Converts a graph with feeds and fetches to an MLIR function.
|
||||||
|
StatusOr<mlir::OwningModuleRef> ConvertGraph(
|
||||||
|
const std::string& name,
|
||||||
|
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
|
||||||
|
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
|
||||||
|
const std::vector<std::string> control_outputs);
|
||||||
|
|
||||||
// Creates GlobalTensorOp for each variable and moves each VarHandle op to
|
// Creates GlobalTensorOp for each variable and moves each VarHandle op to
|
||||||
// the enclosing function's arguments.
|
// the enclosing function's arguments.
|
||||||
|
@ -3273,18 +3296,62 @@ class SavedModelSignatureDefImporter {
|
||||||
GraphImportConfig::InputArrays ParseInputArrays(
|
GraphImportConfig::InputArrays ParseInputArrays(
|
||||||
const std::vector<std::pair<std::string, TensorInfo>>& inputs);
|
const std::vector<std::pair<std::string, TensorInfo>>& inputs);
|
||||||
|
|
||||||
|
const GraphDef& graph_def() const {
|
||||||
|
return bundle_.meta_graph_def.graph_def();
|
||||||
|
}
|
||||||
|
const FunctionLibraryDefinition& flib_def() const { return flib_def_; }
|
||||||
|
const GraphDebugInfo& debug_info() const { return debug_info_; }
|
||||||
|
|
||||||
const SavedModelBundle& bundle_;
|
const SavedModelBundle& bundle_;
|
||||||
|
FunctionLibraryDefinition flib_def_;
|
||||||
|
GraphDebugInfo debug_info_;
|
||||||
absl::Span<std::string> exported_names_;
|
absl::Span<std::string> exported_names_;
|
||||||
mlir::OwningModuleRef module_;
|
mlir::OwningModuleRef module_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Status SavedModelSignatureDefImporter::ConvertInitializer() {
|
||||||
|
std::vector<AssetFileDef> asset_file_defs;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
internal::GetAssetFileDefs(bundle_.meta_graph_def, &asset_file_defs));
|
||||||
|
|
||||||
|
if (!asset_file_defs.empty())
|
||||||
|
return errors::Unimplemented(
|
||||||
|
absl::StrCat("Assets are not supported in signaturedef importer"));
|
||||||
|
|
||||||
|
std::string init_node_name;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
internal::GetInitOp("", bundle_.meta_graph_def, &init_node_name));
|
||||||
|
|
||||||
|
if (init_node_name.empty()) return Status::OK();
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(auto sub_module,
|
||||||
|
ConvertGraph(init_node_name, {}, {}, {init_node_name}));
|
||||||
|
|
||||||
|
mlir::SymbolTable symbol_table(*sub_module);
|
||||||
|
|
||||||
|
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(init_node_name);
|
||||||
|
|
||||||
|
init_func_op.removeAttr("tf.entry_function");
|
||||||
|
|
||||||
|
mlir::OpBuilder builder(module_->getBodyRegion());
|
||||||
|
|
||||||
|
builder.create<mlir::tf_saved_model::SessionInitializerOp>(
|
||||||
|
module_->getLoc(), builder.getSymbolRefAttr(init_func_op.getName()));
|
||||||
|
|
||||||
|
// Move the converted functions to top level MLIR module.
|
||||||
|
auto* block = module_->getBody();
|
||||||
|
auto* sub_block = sub_module->getBody();
|
||||||
|
block->getOperations().splice(
|
||||||
|
mlir::Block::iterator(block->getTerminator()), sub_block->getOperations(),
|
||||||
|
sub_block->begin(), mlir::Block::iterator(sub_block->getTerminator()));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<mlir::OwningModuleRef>
|
StatusOr<mlir::OwningModuleRef>
|
||||||
SavedModelSignatureDefImporter::ConvertSignatures() {
|
SavedModelSignatureDefImporter::ConvertSignatures() {
|
||||||
const auto& signatures = bundle_.GetSignatures();
|
const auto& signatures = bundle_.GetSignatures();
|
||||||
const auto& graphdef = bundle_.meta_graph_def.graph_def();
|
PopulateTfVersions(module_.get(), graph_def().versions());
|
||||||
PopulateTfVersions(module_.get(), graphdef.versions());
|
|
||||||
|
|
||||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graphdef.library());
|
|
||||||
|
|
||||||
// debug_info might not be loaded with loader_lite.
|
// debug_info might not be loaded with loader_lite.
|
||||||
GraphDebugInfo debug_info;
|
GraphDebugInfo debug_info;
|
||||||
|
@ -3307,9 +3374,10 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def,
|
TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def));
|
||||||
debug_info, flib_def));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(ConvertInitializer());
|
||||||
TF_RETURN_IF_ERROR(LiftVariables());
|
TF_RETURN_IF_ERROR(LiftVariables());
|
||||||
|
|
||||||
mlir::OpBuilder builder(module_->getBodyRegion());
|
mlir::OpBuilder builder(module_->getBodyRegion());
|
||||||
|
@ -3320,10 +3388,32 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
|
||||||
return std::move(module_);
|
return std::move(module_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefImporter::ConvertGraph(
|
||||||
|
const std::string& name,
|
||||||
|
const std::vector<std::pair<std::string, TensorInfo>>& inputs,
|
||||||
|
const std::vector<std::pair<std::string, TensorInfo>>& outputs,
|
||||||
|
const std::vector<std::string> control_outputs) {
|
||||||
|
GraphImportConfig specs;
|
||||||
|
specs.prune_unused_nodes = true;
|
||||||
|
specs.inputs = ParseInputArrays(inputs);
|
||||||
|
for (auto& output : outputs) specs.outputs.push_back(output.second.name());
|
||||||
|
specs.control_outputs = control_outputs;
|
||||||
|
|
||||||
|
// Convert sub-graphdef to sub-graph.
|
||||||
|
GraphConstructorOptions options;
|
||||||
|
options.allow_internal_ops = true;
|
||||||
|
options.add_default_attributes = true;
|
||||||
|
Graph graph(OpRegistry::Global());
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(options, graph_def(), &graph));
|
||||||
|
|
||||||
|
// Convert sub-graph to MLIR module.true
|
||||||
|
return GraphDefImporter::Convert(module_->getContext(), graph, debug_info(),
|
||||||
|
flib_def(), specs, name);
|
||||||
|
}
|
||||||
|
|
||||||
Status SavedModelSignatureDefImporter::ConvertSignature(
|
Status SavedModelSignatureDefImporter::ConvertSignature(
|
||||||
const GraphDef& graphdef, const std::string& sig_def_key,
|
const std::string& sig_def_key, const SignatureDef& signature_def) {
|
||||||
const SignatureDef& signature_def, const GraphDebugInfo& debug_info,
|
|
||||||
const FunctionLibraryDefinition& flib_def) {
|
|
||||||
// Create local vectors for the input and output and sort them to be
|
// Create local vectors for the input and output and sort them to be
|
||||||
// deterministic. We don't want anyone to really depend on the order, client
|
// deterministic. We don't want anyone to really depend on the order, client
|
||||||
// should lookup argument/result mapping by attribute name.
|
// should lookup argument/result mapping by attribute name.
|
||||||
|
@ -3339,34 +3429,9 @@ Status SavedModelSignatureDefImporter::ConvertSignature(
|
||||||
return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
|
return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first;
|
||||||
});
|
});
|
||||||
|
|
||||||
GraphImportConfig specs;
|
|
||||||
specs.prune_unused_nodes = true;
|
|
||||||
specs.inputs = ParseInputArrays(inputs);
|
|
||||||
for (auto& output : outputs) specs.outputs.push_back(output.second.name());
|
|
||||||
|
|
||||||
// Remove unused nodes and create sub-graphdef.
|
|
||||||
GraphDef sub_graph_def;
|
|
||||||
TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
|
|
||||||
graphdef, &sub_graph_def,
|
|
||||||
/*terminal_nodes=*/{specs.outputs.begin(), specs.outputs.end()}));
|
|
||||||
|
|
||||||
// Set the function library definitions in the pruned graphdef.
|
|
||||||
*sub_graph_def.mutable_library() = flib_def.ToProto();
|
|
||||||
|
|
||||||
// Convert sub-graphdef to sub-graph.
|
|
||||||
GraphConstructorOptions options;
|
|
||||||
options.allow_internal_ops = true;
|
|
||||||
options.add_default_attributes = true;
|
|
||||||
Graph sub_graph(OpRegistry::Global());
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
ConvertGraphDefToGraph(options, sub_graph_def, &sub_graph));
|
|
||||||
|
|
||||||
// Convert sub-graph to MLIR module.
|
// Convert sub-graph to MLIR module.
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(auto sub_module,
|
||||||
auto sub_module,
|
ConvertGraph(sig_def_key, inputs, outputs, {}));
|
||||||
GraphDefImporter::Convert(module_->getContext(), sub_graph, debug_info,
|
|
||||||
flib_def, specs, sig_def_key));
|
|
||||||
mlir::OpBuilder builder(sub_module->getBodyRegion());
|
mlir::OpBuilder builder(sub_module->getBodyRegion());
|
||||||
|
|
||||||
// Find the FuncOp which corresponds to current SignatureDef.
|
// Find the FuncOp which corresponds to current SignatureDef.
|
||||||
|
|
Loading…
Reference in New Issue