Initial version of the SavedModel V1 importer, which converts a V1 SavedModel to an MLIR module containing one function per signature def.

PiperOrigin-RevId: 288042933
Change-Id: I5dfde397eb8635020025aa1dc6fee690e4b45ae3
A. Unique TensorFlower 2020-01-03 13:46:45 -08:00 committed by TensorFlower Gardener
parent 05dd398ea5
commit d25dd80748
12 changed files with 646 additions and 8 deletions
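For orientation, here is a minimal end-to-end sketch of what this commit enables, using the pywrap binding added below. This is a hedged example, not part of the commit: the save path is illustrative, and the flow mirrors the basic_v1.py test and common_v1.py harness in this change.

import tensorflow.compat.v1 as tf
from tensorflow.python import pywrap_tensorflow

# The V1 importer rejects ref variables, so switch to resource variables
# before building the graph (see LiftVariables() below).
tf.enable_resource_variables()
tf.disable_eager_execution()

x = tf.constant([[1.0], [1.0], [1.0]])
y = tf.get_variable(
    name='y', shape=(1, 3), initializer=tf.random_normal_initializer())
r = tf.matmul(x, y)

sig = tf.saved_model.signature_def_utils.build_signature_def(
    inputs={'x': tf.saved_model.utils.build_tensor_info(x)},
    outputs={'r': tf.saved_model.utils.build_tensor_info(r)},
    method_name=tf.saved_model.PREDICT_METHOD_NAME)

with tf.Session() as sess:
  sess.run(tf.initializers.global_variables())
  # '/tmp/model.saved_model' is an illustrative path.
  builder = tf.saved_model.builder.SavedModelBuilder('/tmp/model.saved_model')
  builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING], {'basic': sig})
  builder.save()

# Import the V1 SavedModel and print the resulting textual MLIR module.
print(pywrap_tensorflow.experimental_convert_saved_model_v1_to_mlir(
    '/tmp/model.saved_model', tf.saved_model.tag_constants.SERVING, False))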


@@ -108,6 +108,45 @@ string ExperimentalConvertSavedModelToMlir(
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
}
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
//
// Args:
// saved_model_path: File path from which to load the SavedModel.
// tags: Comma-separated list of tags identifying the MetaGraphDef to load.
//
// Returns:
// A string of textual MLIR representing the raw imported SavedModel.
string ExperimentalConvertSavedModelV1ToMlir(
const string &saved_model_path,
const string &tags,
bool show_debug_info,
TF_Status* status) {
// Load the saved model into a SavedModelBundle.
std::unordered_set<string> tag_set
= absl::StrSplit(tags, ',', absl::SkipEmpty());
tensorflow::SavedModelBundle bundle;
auto load_status = tensorflow::LoadSavedModel(
{}, {},
saved_model_path, tag_set, &bundle);
if (!load_status.ok()) {
Set_TF_Status_from_Status(status, load_status);
return "// error";
}
// Convert the SavedModelBundle to an MLIR module.
mlir::MLIRContext context;
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";
}
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
}
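A note on the tags parameter: the binding takes a single comma-separated string rather than a list, and empty segments are dropped by absl::SkipEmpty. A hedged Python-side sketch of the same splitting rule (hypothetical helper, not part of this commit):

# Mirrors absl::StrSplit(tags, ',', absl::SkipEmpty()) above.
def parse_tag_set(tags):
  return {t for t in tags.split(',') if t}

assert parse_tag_set('serve,gpu') == {'serve', 'gpu'}
assert parse_tag_set('serve,') == {'serve'}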
string ExperimentalRunPassPipeline(
const string &mlir_txt,
@@ -154,6 +193,7 @@ string ExperimentalRunPassPipeline(
%unignore tensorflow::swig;
%unignore tensorflow::swig::ImportGraphDef;
%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
%unignore tensorflow::swig::ExperimentalConvertSavedModelV1ToMlir;
%unignore tensorflow::swig::ExperimentalRunPassPipeline;
// Wrap this function
@@ -167,6 +207,11 @@ static string ExperimentalConvertSavedModelToMlir(
const string &exported_names,
bool show_debug_info,
TF_Status* status);
static string ExperimentalConvertSavedModelV1ToMlir(
const string &saved_model_path,
const string &tags,
bool show_debug_info,
TF_Status* status);
static string ExperimentalRunPassPipeline(
const string &mlir_txt,
const string &pass_pipeline,
@@ -188,6 +233,14 @@ def experimental_convert_saved_model_to_mlir(saved_model_path,
show_debug_info
).decode('utf-8');
def experimental_convert_saved_model_v1_to_mlir(saved_model_path,
tags, show_debug_info):
return ExperimentalConvertSavedModelV1ToMlir(
str(saved_model_path).encode('utf-8'),
str(tags).encode('utf-8'),
show_debug_info
).decode('utf-8');
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
return ExperimentalRunPassPipeline(
mlir_txt.encode('utf-8'),


@@ -348,15 +348,18 @@ cc_library(
":tensorflow",
":tensorflow_passes",
"//tensorflow/cc/saved_model:bundle_v2",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/compiler/jit:shape_inference_helpers",
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/utils:transitive_fanin",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",


@@ -13,6 +13,15 @@ py_library(
],
)
py_library(
name = "common_v1",
srcs = ["common_v1.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)
filegroup(
name = "test_utilities",
testonly = True,
@@ -24,7 +33,10 @@ filegroup(
# Drop trailing ".py" from all test file names.
all_test_basenames = [py[:-3] for py in glob(
["*.py"],
exclude = [
"common.py",
"common_v1.py",
],
)]
# Instantiate all the tests.


@@ -0,0 +1,64 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/basic_v1 | FileCheck %s
# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "y", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
# CHECK: func @basic([[ARG0:%.*]]: tensor<3x1xf32>,
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32>
# CHECK-NEXT: [[R0:%.*]] = "tf.ReadVariableOp"([[ARG1]]) {{{.*}}} : (tensor<!tf.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
# CHECK-NEXT: return [[R1]] : tensor<3x3xf32>
def Test():
# By default, TF1.x uses reference variables, which are not supported by the
# SavedModel V1 importer. Resource variables must be enabled before building
# the graph.
tf.compat.v1.enable_resource_variables()
tf.compat.v1.disable_eager_execution()
x = tf.constant([[1.0], [1.0], [1.0]])
y = tf.compat.v1.get_variable(
name='y',
shape=(1, 3),
initializer=tf.random_normal_initializer(),
trainable=True)
r = tf.matmul(x, y)
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
return {
'basic':
(tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs={'x': tensor_info_x},
outputs={'r': tensor_info_r},
method_name=tf.saved_model.PREDICT_METHOD_NAME))
}
if __name__ == '__main__':
common_v1.do_test(Test())


@@ -11,6 +11,7 @@ def tf_saved_model_test(name, data):
srcs = [name + ".py"],
deps = [
"//tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model:common",
"//tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model:common_v1",
],
)


@@ -0,0 +1,93 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Serves as a common "main" function for all the SavedModel tests.
There is a fair amount of setup needed to initialize tensorflow and get it
into a proper TF2 execution mode. This hides that boilerplate.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from absl import app
from absl import flags
from absl import logging
import tensorflow.compat.v1 as tf
from tensorflow.python import pywrap_tensorflow
# Use /tmp to make debugging the tests easier (see README.md)
flags.DEFINE_string('save_model_path', '', 'Path to save the model to.')
FLAGS = flags.FLAGS
# Saves the given signature_def_map as a SavedModel V1 and converts it to
# MLIR. The actual work happens inside absl.app.run, after absl and
# tensorflow have run their initialization steps and flags have been parsed.
def do_test(signature_def_map, show_debug_info=False):
"""Runs test.
1. Performs absl and tf "main"-like initialization that must run before almost
anything else.
2. Converts signature_def_map to SavedModel V1
3. Converts SavedModel V1 to MLIR
4. Prints the textual MLIR to stdout (it is expected that the caller will have
FileCheck checks in its file to check this output).
This is only for use by the MLIR SavedModel importer tests.
Args:
signature_def_map: A map from string key to signature_def. The key will be
used as function name in the resulting MLIR.
show_debug_info: If true, shows debug locations in the resulting MLIR.
"""
# Make LOG(ERROR) in C++ code show up on the console.
# All `Status` passed around in the C++ API seem to eventually go into
# `LOG(ERROR)`, so this makes them print out by default.
logging.set_stderrthreshold('error')
def app_main(argv):
"""Function passed to absl.app.run."""
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
if FLAGS.save_model_path:
save_model_path = FLAGS.save_model_path
else:
save_model_path = tempfile.mktemp(suffix='.saved_model')
sess = tf.Session()
sess.run(tf.initializers.global_variables())
builder = tf.saved_model.builder.SavedModelBuilder(save_model_path)
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map,
strip_default_attrs=True)
builder.save()
logging.info('Saved model to: %s', save_model_path)
mlir = pywrap_tensorflow.experimental_convert_saved_model_v1_to_mlir(
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
show_debug_info)
# We don't strictly need this, but it serves as a handy sanity check
# for that API, which is otherwise a bit annoying to test.
# The canonicalization shouldn't affect these tests in any way.
mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
mlir, 'tf-standard-pipeline', show_debug_info)
print(mlir)
app.run(app_main)


@@ -0,0 +1,64 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/shared_variable_v1 | FileCheck %s
# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "y", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
# CHECK: func {{@.*}}([[ARG0:%.*]]: tensor<3x1xf32>,
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32>
# CHECK: func {{@.*}}([[ARG2:%.*]]: tensor<3x1xf32>,
# CHECK-SAME: [[ARG3:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32>
def Test():
# By default, TF1.x uses reference variables, which are not supported by the
# SavedModel V1 importer. Resource variables must be enabled before building
# the graph.
tf.enable_resource_variables()
tf.compat.v1.disable_eager_execution()
x = tf.constant([[1.0], [1.0], [1.0]])
y = tf.get_variable(
name='y',
shape=(1, 3),
initializer=tf.random_normal_initializer(),
trainable=True)
r = tf.matmul(x, y)
tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_r = tf.saved_model.utils.build_tensor_info(r)
signature_def = tf.saved_model.signature_def_utils.build_signature_def(
inputs={'x': tensor_info_x},
outputs={'r': tensor_info_r},
method_name=tf.saved_model.PREDICT_METHOD_NAME)
# Create two signatures that share the same variable.
return {'basic': signature_def, 'basic_2': signature_def}
if __name__ == '__main__':
common_v1.do_test(Test())


@@ -35,6 +35,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
@@ -71,6 +72,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
@@ -81,6 +83,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -1734,8 +1737,8 @@ class GraphDefImporter : public ImporterBase {
static StatusOr<mlir::OwningModuleRef> Convert(
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
llvm::StringRef func_name);
private:
explicit GraphDefImporter(
@@ -1773,7 +1776,7 @@ class GraphDefImporter : public ImporterBase {
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
const GraphImportConfig& specs, llvm::StringRef func_name) {
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
@@ -1861,7 +1864,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
{producer, min_consumer, bad_consumers})));
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs,
resource_arg_unique_ids));
return module;
}
@@ -2771,6 +2774,292 @@ StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert(
return module;
}
// A helper class to import a TensorFlow model expressed in SavedModel V1 into
// an MLIR Module.
class SavedModelV1Importer {
public:
// Main entry point: converts all functions (specified by SignatureDefs) in
// the given meta graph to an MLIR Module.
static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle,
mlir::MLIRContext* context) {
SavedModelV1Importer importer(bundle, context);
return importer.ConvertSignatures();
}
private:
SavedModelV1Importer(const SavedModelBundle& bundle,
mlir::MLIRContext* context)
: bundle_(bundle),
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
// Converts the SavedModel to the TF executor dialect, creating an MLIR
// function for each signature.
StatusOr<mlir::OwningModuleRef> ConvertSignatures();
StatusOr<mlir::OwningModuleRef> ConvertSignature(
const GraphImportConfig& specs, llvm::StringRef func_name,
const SignatureDef& signature_def, const GraphDef& sub_graph_def,
const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def);
// Creates a GlobalTensorOp for each variable and moves each VarHandleOp to
// the enclosing function's arguments.
Status LiftVariables();
void LiftVariable(mlir::TF::VarHandleOp op);
// Reads all variables from the SavedModel through the session and creates a
// GlobalTensorOp for each of them.
Status ReadVariablesFromSession(
const llvm::SmallVectorImpl<mlir::TF::VarHandleOp>& ops);
GraphImportConfig::InputArrays ParseInputArrays(
const tensorflow::protobuf::Map<std::string, TensorInfo>& inputs);
std::vector<std::string> ParseOutputArrays(
const tensorflow::protobuf::Map<std::string, TensorInfo>& outputs);
const SavedModelBundle& bundle_;
mlir::OwningModuleRef module_;
};
// Converts the SavedModel to the TF executor dialect, creating an MLIR
// function for each signature.
StatusOr<mlir::OwningModuleRef> SavedModelV1Importer::ConvertSignatures() {
const auto& signatures = bundle_.GetSignatures();
const auto& graphdef = bundle_.meta_graph_def.graph_def();
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graphdef.library());
// debug_info might not be loaded with loader_lite.
GraphDebugInfo debug_info;
if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
for (const auto& key_and_signature_def : signatures) {
const auto& func_name = key_and_signature_def.first;
const auto& signature_def = key_and_signature_def.second;
GraphImportConfig specs;
specs.inputs = ParseInputArrays(signature_def.inputs());
specs.outputs = ParseOutputArrays(signature_def.outputs());
// Remove unused nodes and create a sub GraphDef limited to the outputs'
// transitive fanin.
GraphDef sub_graph_def;
TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
graphdef, &sub_graph_def,
/* terminal_nodes = */ {specs.outputs.begin(), specs.outputs.end()}));
auto status_or_sub_module = ConvertSignature(
specs, func_name, signature_def, sub_graph_def, debug_info, flib_def);
if (!status_or_sub_module.ok()) {
LOG(ERROR) << "Failed to convert SignatureDef for " << func_name << ": "
<< status_or_sub_module.status();
continue;
}
auto& sub_module = status_or_sub_module.ValueOrDie();
// Move the converted functions to the top-level MLIR module.
auto* block = module_->getBody();
auto* sub_block = sub_module->getBody();
block->getOperations().splice(
mlir::Block::iterator(block->getTerminator()),
sub_block->getOperations(), sub_block->begin(),
mlir::Block::iterator(sub_block->getTerminator()));
}
TF_RETURN_IF_ERROR(LiftVariables());
return std::move(module_);
}
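SetTransitiveFaninGraph keeps only the nodes each signature's outputs depend on. A hedged Python analogue of the same pruning step (using the long-standing graph_util helper, not part of this commit; prune_for_signature is a hypothetical name):

from tensorflow.python.framework import graph_util

def prune_for_signature(graph_def, signature_def):
  # TensorInfo names look like 'MatMul:0'; the pruner wants node names.
  dest_nodes = [info.name.split(':')[0]
                for info in signature_def.outputs.values()]
  return graph_util.extract_sub_graph(graph_def, dest_nodes)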
StatusOr<mlir::OwningModuleRef> SavedModelV1Importer::ConvertSignature(
const GraphImportConfig& specs, llvm::StringRef func_name,
const SignatureDef& signature_def, const GraphDef& sub_graph_def,
const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def) {
// Convert the pruned GraphDef to a Graph.
GraphConstructorOptions options;
options.allow_internal_ops = true;
options.add_default_attributes = true;
Graph sub_graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(options, sub_graph_def, &sub_graph));
// Convert the sub Graph to an MLIR function.
return GraphDefImporter::Convert(module_->getContext(), sub_graph, debug_info,
flib_def, specs, func_name);
}
// Creates a GlobalTensorOp for each variable and moves each VarHandleOp to
// the enclosing function's arguments.
Status SavedModelV1Importer::LiftVariables() {
llvm::SmallVector<mlir::TF::VarHandleOp, 4> ops;
bool contains_ref_variable = false;
module_->walk([&ops, &contains_ref_variable](mlir::Operation* op) {
if (auto var_handle_op = llvm::dyn_cast<mlir::TF::VarHandleOp>(op))
ops.push_back(var_handle_op);
else if (op->getName().getStringRef() == "tf.VariableV2")
contains_ref_variable = true;
});
if (contains_ref_variable)
return errors::InvalidArgument(
"Ref variable created by VariableV2 is not supported.");
if (ops.empty()) return Status::OK();
TF_RETURN_IF_ERROR(ReadVariablesFromSession(ops));
for (auto op : ops) LiftVariable(op);
return Status::OK();
}
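The VariableV2 check above is why the tests call enable_resource_variables(). A hedged standalone sketch (not from this commit) showing how that switch changes the op emitted by get_variable:

import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default():
  tf.enable_resource_variables()
  tf.get_variable('v', shape=(), initializer=tf.zeros_initializer())

print([n.op for n in g.as_graph_def().node if n.name == 'v'])
# -> ['VarHandleOp']. Without enable_resource_variables() the same call
# creates a 'VariableV2' node, which LiftVariables() rejects above.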
// Moves the result of the VarHandleOp to the enclosing function's argument
// list and erases the VarHandleOp.
void SavedModelV1Importer::LiftVariable(mlir::TF::VarHandleOp op) {
mlir::OpBuilder builder(&module_->getBodyRegion());
auto func_op = op.getParentOfType<mlir::FuncOp>();
builder.setInsertionPoint(func_op);
auto func_type = func_op.getType();
// Create the new function type by adding variable type to the arguments.
llvm::SmallVector<mlir::Type, 4> new_input_types(
func_type.getInputs().begin(), func_type.getInputs().end());
new_input_types.push_back(op.resource()->getType());
auto new_func_type =
builder.getFunctionType(new_input_types, func_type.getResults());
auto new_func_op = builder.create<mlir::FuncOp>(
func_op.getLoc(), func_op.getName(), new_func_type,
llvm::ArrayRef<mlir::NamedAttribute>());
// Bind the argument to the corresponding global tensor op.
new_func_op.setArgAttr(new_func_op.getNumArguments() - 1,
"tf_saved_model.bound_input",
builder.getSymbolRefAttr(op.shared_name()));
// Replace the function body and update its signature.
auto& new_region = new_func_op.getBody();
new_region.getBlocks().splice(new_region.end(),
func_op.getBody().getBlocks());
func_op.getOperation()->erase();
auto& new_block = new_region.front();
auto new_value = new_block.addArgument(op.resource()->getType());
op.getOperation()->replaceAllUsesWith(llvm::ArrayRef<mlir::Value>(new_value));
op.getOperation()->erase();
}
// Reads all variables from the SavedModel through the session and creates a
// GlobalTensorOp for each of them.
Status SavedModelV1Importer::ReadVariablesFromSession(
const llvm::SmallVectorImpl<mlir::TF::VarHandleOp>& ops) {
mlir::OpBuilder builder(&module_->getBodyRegion());
// Collect one VarHandleOp per variable, keyed by its shared name.
llvm::MapVector<llvm::StringRef, mlir::TF::VarHandleOp>
variable_names_and_ops;
for (auto op : ops) {
variable_names_and_ops[op.shared_name()] = op;
}
// Read all resource variables from the session.
std::vector<std::string> variable_names;
variable_names.reserve(variable_names_and_ops.size());
for (const auto& name_and_location : variable_names_and_ops)
variable_names.push_back(name_and_location.first);
std::vector<Tensor> resource_tensors;
TF_RETURN_IF_ERROR(bundle_.GetSession()->Run(
/*inputs=*/{}, variable_names,
/*target_node_names=*/{}, &resource_tensors));
const DeviceMgr* device_manager;
TF_RETURN_IF_ERROR(bundle_.GetSession()->LocalDeviceManager(&device_manager));
// Read all underlying tensors of the variables from the session.
std::vector<Tensor> tensors;
tensors.reserve(resource_tensors.size());
for (const auto& resource_tensor : resource_tensors) {
const auto& resource_handle = resource_tensor.scalar<ResourceHandle>()();
Device* device;
TF_RETURN_IF_ERROR(
device_manager->LookupDevice(resource_handle.device(), &device));
Var* var_ptr;
TF_RETURN_IF_ERROR(device->resource_manager()->Lookup(
resource_handle.container(), resource_handle.name(), &var_ptr));
core::RefCountPtr<Var> var(var_ptr);
// The variable tensor is already loaded into corresponding device's
// resource manager when we load the saved model using LoadSavedModel().
// Here we just read its value.
mutex_lock ml(*var->mu());
tensors.push_back(*var->tensor());
}
for (const auto& iter : llvm::zip(variable_names_and_ops, tensors)) {
const auto& name = std::get<0>(iter).first;
auto location = std::get<0>(iter).second.getLoc();
const auto& tensor = std::get<1>(iter);
// Create tensor attribute for this variable.
TF_ASSIGN_OR_RETURN(auto tensor_attr, ConvertTensor(tensor, &builder));
builder.create<mlir::tf_saved_model::GlobalTensorOp>(
location, builder.getStringAttr(name), tensor_attr,
mlir::TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
}
return Status::OK();
}
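ReadVariablesFromSession pulls variable values out of the device resource managers that LoadSavedModel populated. A hedged Python-level equivalent (the path is illustrative, and assumes the SavedModel stores its variables collection, as SavedModelBuilder does):

import tensorflow.compat.v1 as tf

with tf.Session(graph=tf.Graph()) as sess:
  tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
                             '/tmp/model.saved_model')
  # Fetch the current value of every restored variable.
  values = {v.name: sess.run(v.read_value())
            for v in tf.global_variables()}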
GraphImportConfig::InputArrays SavedModelV1Importer::ParseInputArrays(
const tensorflow::protobuf::Map<std::string, TensorInfo>& inputs) {
GraphImportConfig::InputArrays results;
for (const auto& iter : inputs) {
const auto& tensor_info = iter.second;
// Only dense tensors are supported.
DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName);
ArrayInfo array_info;
array_info.imported_dtype = tensor_info.dtype();
array_info.shape = tensor_info.tensor_shape();
std::vector<std::string> node_names =
absl::StrSplit(tensor_info.name(), ':');
results.insert(std::pair<std::string, ArrayInfo>(node_names.at(0),
std::move(array_info)));
}
return results;
}
std::vector<std::string> SavedModelV1Importer::ParseOutputArrays(
const tensorflow::protobuf::Map<std::string, TensorInfo>& outputs) {
std::vector<std::string> results;
for (const auto& iter : outputs) {
const auto& tensor_info = iter.second;
std::vector<std::string> node_names =
absl::StrSplit(tensor_info.name(), ':');
results.push_back(node_names.at(0));
}
return results;
}
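Both parsers above reduce SignatureDef tensor names to graph node names by dropping the output index. A hedged one-liner capturing that convention (hypothetical helper name):

def tensor_name_to_node_name(tensor_name):
  # 'MatMul:0' -> 'MatMul'; a bare node name passes through unchanged.
  return tensor_name.split(':')[0]

assert tensor_name_to_node_name('MatMul:0') == 'MatMul'
assert tensor_name_to_node_name('x') == 'x'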
} // namespace
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) {
@@ -2806,7 +3095,8 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
UpgradeLegacyGraph(const_cast<Graph*>(&graph),
const_cast<FunctionLibraryDefinition*>(&flib_def)));
}
return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
/* func_name = */ "main");
}
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
@@ -2816,6 +3106,11 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
add_default_attributes);
}
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
const SavedModelBundle& saved_model, mlir::MLIRContext* context) {
return SavedModelV1Importer::Convert(saved_model, context);
}
std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
std::string txt_module;
{


@@ -21,6 +21,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -50,6 +51,12 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, bool add_default_attributes = true);
// Given a V1 SavedModel, returns an MLIR module containing the functions,
// expressed in the tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef>
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
mlir::MLIRContext* context);
// Serialize an MLIR module to a string.
std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false);


@@ -130,6 +130,27 @@ mlir::OwningModuleRef SavedModelToMlirImport(
return module_or.ConsumeValueOrDie();
}
mlir::OwningModuleRef SavedModelV1ToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) {
tensorflow::SavedModelBundle bundle;
auto load_status = tensorflow::LoadSavedModel(
/* session_options = */ {}, /* run_options = */ {},
std::string(saved_model_dir), tags, &bundle);
if (!load_status.ok()) {
LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
<< "': " << load_status;
return nullptr;
}
auto module_or = ConvertSavedModelV1ToMlir(bundle, context);
if (!module_or.status().ok()) {
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
return nullptr;
}
return module_or.ConsumeValueOrDie();
}
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
llvm::StringRef input, absl::string_view debug_info_file,
absl::string_view input_arrays, absl::string_view input_dtypes,


@@ -54,6 +54,14 @@ mlir::OwningModuleRef SavedModelToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags,
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
// Converts a TensorFlow V1 SavedModel stored in the given `saved_model_dir`
// directory into an MLIR module, creating MLIR entities in the given
// `context`.
mlir::OwningModuleRef SavedModelV1ToMlirImport(
absl::string_view saved_model_dir,
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_


@@ -54,6 +54,12 @@ static llvm::cl::opt<bool> import_saved_model(
llvm::cl::desc("Import a saved model to its MLIR representation"),
llvm::cl::value_desc("dir"));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> import_saved_model_v1(
"savedmodel-v1-to-mlir",
llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
llvm::cl::value_desc("dir"));
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> saved_model_tags(
"tf-savedmodel-tags",
@@ -77,10 +83,11 @@ int main(int argc, char** argv) {
llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n");
if (!import_saved_model && !import_saved_model_v1 && !requested_translation) {
llvm::errs() << "error: need to specify one translation to perform\n";
return 1;
} else if ((import_saved_model ? 1 : 0) + (import_saved_model_v1 ? 1 : 0) +
(requested_translation ? 1 : 0) > 1) {
llvm::errs()
<< "error: cannot specify more than one translation to perform\n";
return 1;
@@ -105,6 +112,16 @@ int main(int argc, char** argv) {
&context);
if (!module) return 1;
module->print(output->os());
} else if (import_saved_model_v1) {
std::unordered_set<std::string> tags =
absl::StrSplit(saved_model_tags, ',');
mlir::MLIRContext context;
auto module =
tensorflow::SavedModelV1ToMlirImport(input_filename, tags, &context);
if (!module) return 1;
module->print(output->os()); module->print(output->os());
} else { } else {
auto input = mlir::openInputFile(input_filename, &error_message); auto input = mlir::openInputFile(input_filename, &error_message);