Initial version of SavedModel V1 Importer that converts a V1 SavedModel to a
MLIR Module that contains functions specified by signature defs. PiperOrigin-RevId: 288042933 Change-Id: I5dfde397eb8635020025aa1dc6fee690e4b45ae3
This commit is contained in:
parent
05dd398ea5
commit
d25dd80748
@ -108,6 +108,45 @@ string ExperimentalConvertSavedModelToMlir(
|
|||||||
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
|
||||||
|
//
|
||||||
|
// Args:
|
||||||
|
// saved_model_path: File path from which to load the SavedModel.
|
||||||
|
// tags: Tags to identify MetaGraphDef that need to be loaded.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// A string of textual MLIR representing the raw imported SavedModel.
|
||||||
|
string ExperimentalConvertSavedModelV1ToMlir(
|
||||||
|
const string &saved_model_path,
|
||||||
|
const string &tags,
|
||||||
|
bool show_debug_info,
|
||||||
|
TF_Status* status) {
|
||||||
|
// Load the saved model into a SavedModelBundle.
|
||||||
|
|
||||||
|
std::unordered_set<string> tag_set
|
||||||
|
= absl::StrSplit(tags, ',', absl::SkipEmpty());
|
||||||
|
|
||||||
|
tensorflow::SavedModelBundle bundle;
|
||||||
|
auto load_status = tensorflow::LoadSavedModel(
|
||||||
|
{}, {},
|
||||||
|
saved_model_path, tag_set, &bundle);
|
||||||
|
if (!load_status.ok()) {
|
||||||
|
Set_TF_Status_from_Status(status, load_status);
|
||||||
|
return "// error";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert the SavedModelBundle to an MLIR module.
|
||||||
|
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
auto module_or = ConvertSavedModelV1ToMlir(bundle, &context);
|
||||||
|
if (!module_or.status().ok()) {
|
||||||
|
Set_TF_Status_from_Status(status, module_or.status());
|
||||||
|
return "// error";
|
||||||
|
}
|
||||||
|
|
||||||
|
return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
string ExperimentalRunPassPipeline(
|
string ExperimentalRunPassPipeline(
|
||||||
const string &mlir_txt,
|
const string &mlir_txt,
|
||||||
@ -154,6 +193,7 @@ string ExperimentalRunPassPipeline(
|
|||||||
%unignore tensorflow::swig;
|
%unignore tensorflow::swig;
|
||||||
%unignore tensorflow::swig::ImportGraphDef;
|
%unignore tensorflow::swig::ImportGraphDef;
|
||||||
%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
|
%unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir;
|
||||||
|
%unignore tensorflow::swig::ExperimentalConvertSavedModelV1ToMlir;
|
||||||
%unignore tensorflow::swig::ExperimentalRunPassPipeline;
|
%unignore tensorflow::swig::ExperimentalRunPassPipeline;
|
||||||
|
|
||||||
// Wrap this function
|
// Wrap this function
|
||||||
@ -167,6 +207,11 @@ static string ExperimentalConvertSavedModelToMlir(
|
|||||||
const string &exported_names,
|
const string &exported_names,
|
||||||
bool show_debug_info,
|
bool show_debug_info,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
static string ExperimentalConvertSavedModelV1ToMlir(
|
||||||
|
const string &saved_model_path,
|
||||||
|
const string &tags,
|
||||||
|
bool show_debug_info,
|
||||||
|
TF_Status* status);
|
||||||
static string ExperimentalRunPassPipeline(
|
static string ExperimentalRunPassPipeline(
|
||||||
const string &mlir_txt,
|
const string &mlir_txt,
|
||||||
const string &pass_pipeline,
|
const string &pass_pipeline,
|
||||||
@ -188,6 +233,14 @@ def experimental_convert_saved_model_to_mlir(saved_model_path,
|
|||||||
show_debug_info
|
show_debug_info
|
||||||
).decode('utf-8');
|
).decode('utf-8');
|
||||||
|
|
||||||
|
def experimental_convert_saved_model_v1_to_mlir(saved_model_path,
|
||||||
|
tags, show_debug_info):
|
||||||
|
return ExperimentalConvertSavedModelV1ToMlir(
|
||||||
|
str(saved_model_path).encode('utf-8'),
|
||||||
|
str(tags).encode('utf-8'),
|
||||||
|
show_debug_info
|
||||||
|
).decode('utf-8');
|
||||||
|
|
||||||
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
|
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
|
||||||
return ExperimentalRunPassPipeline(
|
return ExperimentalRunPassPipeline(
|
||||||
mlir_txt.encode('utf-8'),
|
mlir_txt.encode('utf-8'),
|
||||||
|
@ -348,15 +348,18 @@ cc_library(
|
|||||||
":tensorflow",
|
":tensorflow",
|
||||||
":tensorflow_passes",
|
":tensorflow_passes",
|
||||||
"//tensorflow/cc/saved_model:bundle_v2",
|
"//tensorflow/cc/saved_model:bundle_v2",
|
||||||
|
"//tensorflow/cc/saved_model:loader_lite",
|
||||||
"//tensorflow/compiler/jit:shape_inference_helpers",
|
"//tensorflow/compiler/jit:shape_inference_helpers",
|
||||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||||
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
|
"//tensorflow/compiler/tf2xla:functionalize_control_flow",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:graph",
|
"//tensorflow/core:graph",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core/grappler/utils:transitive_fanin",
|
||||||
"//tensorflow/core/platform:types",
|
"//tensorflow/core/platform:types",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
@ -13,6 +13,15 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "common_v1",
|
||||||
|
srcs = ["common_v1.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
name = "test_utilities",
|
name = "test_utilities",
|
||||||
testonly = True,
|
testonly = True,
|
||||||
@ -24,7 +33,10 @@ filegroup(
|
|||||||
# Drop trailing ".py" from all test file names.
|
# Drop trailing ".py" from all test file names.
|
||||||
all_test_basenames = [py[:-3] for py in glob(
|
all_test_basenames = [py[:-3] for py in glob(
|
||||||
["*.py"],
|
["*.py"],
|
||||||
exclude = ["common.py"],
|
exclude = [
|
||||||
|
"common.py",
|
||||||
|
"common_v1.py",
|
||||||
|
],
|
||||||
)]
|
)]
|
||||||
|
|
||||||
# Instantiate all the tests.
|
# Instantiate all the tests.
|
||||||
|
@ -0,0 +1,64 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
# RUN: %p/basic_v1 | FileCheck %s
|
||||||
|
|
||||||
|
# pylint: disable=missing-docstring,line-too-long
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||||
|
|
||||||
|
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "y", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
|
||||||
|
# CHECK: func @basic([[ARG0:%.*]]: tensor<3x1xf32>,
|
||||||
|
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32>
|
||||||
|
# CHECK-NEXT: [[R0:%.*]] = "tf.ReadVariableOp"([[ARG1]]) {{{.*}}} : (tensor<!tf.resource<tensor<1x3xf32>>>) -> tensor<1x3xf32>
|
||||||
|
# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
|
||||||
|
# CHECK-NEXT: return [[R1]] : tensor<3x3xf32>
|
||||||
|
|
||||||
|
|
||||||
|
def Test():
|
||||||
|
|
||||||
|
# Default TF1.x uses reference variables that are not supported by SavedModel
|
||||||
|
# v1 Importer. To use SavedModel V1 Importer, resource variables should be
|
||||||
|
# enabled.
|
||||||
|
tf.compat.v1.enable_resource_variables()
|
||||||
|
|
||||||
|
tf.compat.v1.disable_eager_execution()
|
||||||
|
|
||||||
|
x = tf.constant([[1.0], [1.0], [1.0]])
|
||||||
|
y = tf.compat.v1.get_variable(
|
||||||
|
name='y',
|
||||||
|
shape=(1, 3),
|
||||||
|
initializer=tf.random_normal_initializer(),
|
||||||
|
trainable=True)
|
||||||
|
r = tf.matmul(x, y)
|
||||||
|
|
||||||
|
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
|
||||||
|
tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'basic':
|
||||||
|
(tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
|
||||||
|
inputs={'x': tensor_info_x},
|
||||||
|
outputs={'r': tensor_info_r},
|
||||||
|
method_name=tf.saved_model.PREDICT_METHOD_NAME))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
common_v1.do_test(Test())
|
@ -11,6 +11,7 @@ def tf_saved_model_test(name, data):
|
|||||||
srcs = [name + ".py"],
|
srcs = [name + ".py"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model:common",
|
"//tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model:common",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model:common_v1",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -0,0 +1,93 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Serves as a common "main" function for all the SavedModel tests.
|
||||||
|
|
||||||
|
There is a fair amount of setup needed to initialize tensorflow and get it
|
||||||
|
into a proper TF2 execution mode. This hides that boilerplate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
from absl import logging
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
from tensorflow.python import pywrap_tensorflow
|
||||||
|
|
||||||
|
# Use /tmp to make debugging the tests easier (see README.md)
|
||||||
|
flags.DEFINE_string('save_model_path', '', 'Path to save the model to.')
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
|
||||||
|
# This function needs to take a "create_module_fn", as opposed to just the
|
||||||
|
# module itself, because the creation of the module has to be delayed until
|
||||||
|
# after absl and tensorflow have run various initialization steps.
|
||||||
|
def do_test(signature_def_map, show_debug_info=False):
|
||||||
|
"""Runs test.
|
||||||
|
|
||||||
|
1. Performs absl and tf "main"-like initialization that must run before almost
|
||||||
|
anything else.
|
||||||
|
2. Converts signature_def_map to SavedModel V1
|
||||||
|
3. Converts SavedModel V1 to MLIR
|
||||||
|
4. Prints the textual MLIR to stdout (it is expected that the caller will have
|
||||||
|
FileCheck checks in its file to check this output).
|
||||||
|
|
||||||
|
This is only for use by the MLIR SavedModel importer tests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
signature_def_map: A map from string key to signature_def. The key will be
|
||||||
|
used as function name in the resulting MLIR.
|
||||||
|
show_debug_info: If true, shows debug locations in the resulting MLIR.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Make LOG(ERROR) in C++ code show up on the console.
|
||||||
|
# All `Status` passed around in the C++ API seem to eventually go into
|
||||||
|
# `LOG(ERROR)`, so this makes them print out by default.
|
||||||
|
logging.set_stderrthreshold('error')
|
||||||
|
|
||||||
|
def app_main(argv):
|
||||||
|
"""Function passed to absl.app.run."""
|
||||||
|
if len(argv) > 1:
|
||||||
|
raise app.UsageError('Too many command-line arguments.')
|
||||||
|
if FLAGS.save_model_path:
|
||||||
|
save_model_path = FLAGS.save_model_path
|
||||||
|
else:
|
||||||
|
save_model_path = tempfile.mktemp(suffix='.saved_model')
|
||||||
|
|
||||||
|
sess = tf.Session()
|
||||||
|
sess.run(tf.initializers.global_variables())
|
||||||
|
builder = tf.saved_model.builder.SavedModelBuilder(save_model_path)
|
||||||
|
builder.add_meta_graph_and_variables(
|
||||||
|
sess, [tf.saved_model.tag_constants.SERVING],
|
||||||
|
signature_def_map,
|
||||||
|
strip_default_attrs=True)
|
||||||
|
builder.save()
|
||||||
|
|
||||||
|
logging.info('Saved model to: %s', save_model_path)
|
||||||
|
mlir = pywrap_tensorflow.experimental_convert_saved_model_v1_to_mlir(
|
||||||
|
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
|
||||||
|
show_debug_info)
|
||||||
|
# We don't strictly need this, but it serves as a handy sanity check
|
||||||
|
# for that API, which is otherwise a bit annoying to test.
|
||||||
|
# The canonicalization shouldn't affect these tests in any way.
|
||||||
|
mlir = pywrap_tensorflow.experimental_run_pass_pipeline(
|
||||||
|
mlir, 'tf-standard-pipeline', show_debug_info)
|
||||||
|
print(mlir)
|
||||||
|
|
||||||
|
app.run(app_main)
|
@ -0,0 +1,64 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
# RUN: %p/shared_variable_v1 | FileCheck %s
|
||||||
|
|
||||||
|
# pylint: disable=missing-docstring,line-too-long
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
|
||||||
|
|
||||||
|
# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "y", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> ()
|
||||||
|
# CHECK: func {{@.*}}([[ARG0:%.*]]: tensor<3x1xf32>,
|
||||||
|
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32>
|
||||||
|
|
||||||
|
# CHECK: func {{@.*}}([[ARG2:%.*]]: tensor<3x1xf32>,
|
||||||
|
# CHECK-SAME: [[ARG3:%.*]]: tensor<!tf.resource<tensor<1x3xf32>>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32>
|
||||||
|
|
||||||
|
|
||||||
|
def Test():
|
||||||
|
|
||||||
|
# Default TF1.x uses reference variables that are not supported by SavedModel
|
||||||
|
# v1 Importer. To use SavedModel V1 Importer, resource variables should be
|
||||||
|
# enabled.
|
||||||
|
tf.enable_resource_variables()
|
||||||
|
|
||||||
|
tf.compat.v1.disable_eager_execution()
|
||||||
|
|
||||||
|
x = tf.constant([[1.0], [1.0], [1.0]])
|
||||||
|
y = tf.get_variable(
|
||||||
|
name='y',
|
||||||
|
shape=(1, 3),
|
||||||
|
initializer=tf.random_normal_initializer(),
|
||||||
|
trainable=True)
|
||||||
|
r = tf.matmul(x, y)
|
||||||
|
|
||||||
|
tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
|
||||||
|
tensor_info_r = tf.saved_model.utils.build_tensor_info(r)
|
||||||
|
|
||||||
|
signature_def = tf.saved_model.signature_def_utils.build_signature_def(
|
||||||
|
inputs={'x': tensor_info_x},
|
||||||
|
outputs={'r': tensor_info_r},
|
||||||
|
method_name=tf.saved_model.PREDICT_METHOD_NAME)
|
||||||
|
|
||||||
|
# Create two signatures that share the same variable.
|
||||||
|
return {'basic': signature_def, 'basic_2': signature_def}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
common_v1.do_test(Test())
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/ADT/Twine.h"
|
#include "llvm/ADT/Twine.h"
|
||||||
@ -71,6 +72,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/resource_var.h"
|
||||||
#include "tensorflow/core/framework/shape_inference.h"
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
@ -81,6 +83,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
#include "tensorflow/core/graph/tensor_id.h"
|
#include "tensorflow/core/graph/tensor_id.h"
|
||||||
|
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
@ -1734,8 +1737,8 @@ class GraphDefImporter : public ImporterBase {
|
|||||||
static StatusOr<mlir::OwningModuleRef> Convert(
|
static StatusOr<mlir::OwningModuleRef> Convert(
|
||||||
mlir::MLIRContext* context, const Graph& graph,
|
mlir::MLIRContext* context, const Graph& graph,
|
||||||
const GraphDebugInfo& debug_info,
|
const GraphDebugInfo& debug_info,
|
||||||
const FunctionLibraryDefinition& flib_def,
|
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
|
||||||
const GraphImportConfig& specs);
|
llvm::StringRef func_name);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
explicit GraphDefImporter(
|
explicit GraphDefImporter(
|
||||||
@ -1773,7 +1776,7 @@ class GraphDefImporter : public ImporterBase {
|
|||||||
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||||
mlir::MLIRContext* context, const Graph& graph,
|
mlir::MLIRContext* context, const Graph& graph,
|
||||||
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
|
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
|
||||||
const GraphImportConfig& specs) {
|
const GraphImportConfig& specs, llvm::StringRef func_name) {
|
||||||
mlir::OwningModuleRef module =
|
mlir::OwningModuleRef module =
|
||||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||||
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
|
||||||
@ -1861,7 +1864,7 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
|||||||
{producer, min_consumer, bad_consumers})));
|
{producer, min_consumer, bad_consumers})));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
|
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
|
||||||
"main", func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs,
|
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs,
|
||||||
resource_arg_unique_ids));
|
resource_arg_unique_ids));
|
||||||
return module;
|
return module;
|
||||||
}
|
}
|
||||||
@ -2771,6 +2774,292 @@ StatusOr<mlir::OwningModuleRef> SavedModelImporter::Convert(
|
|||||||
return module;
|
return module;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A helper class to import a TensorFlow model expressed in SavedModel V1 into
|
||||||
|
// an MLIR Module.
|
||||||
|
class SavedModelV1Importer {
|
||||||
|
public:
|
||||||
|
// Main entry point: converts all functions (specified by SignatureDefs) in
|
||||||
|
// the given meta graph to an MLIR Module.
|
||||||
|
static StatusOr<mlir::OwningModuleRef> Convert(const SavedModelBundle& bundle,
|
||||||
|
mlir::MLIRContext* context) {
|
||||||
|
SavedModelV1Importer importer(bundle, context);
|
||||||
|
|
||||||
|
return importer.ConvertSignatures();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SavedModelV1Importer(const SavedModelBundle& bundle,
|
||||||
|
mlir::MLIRContext* context)
|
||||||
|
: bundle_(bundle),
|
||||||
|
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {}
|
||||||
|
|
||||||
|
// Convert the SavedModel to TF Executor Dialect. It creates a MLIR function
|
||||||
|
// for each signature.
|
||||||
|
StatusOr<mlir::OwningModuleRef> ConvertSignatures();
|
||||||
|
StatusOr<mlir::OwningModuleRef> ConvertSignature(
|
||||||
|
const GraphImportConfig& specs, llvm::StringRef func_name,
|
||||||
|
const SignatureDef& signature_def, const GraphDef& sub_graph_def,
|
||||||
|
const GraphDebugInfo& debug_info,
|
||||||
|
const FunctionLibraryDefinition& flib_def);
|
||||||
|
|
||||||
|
// Create GlobalTensorOp for each variable and move each VarHandle op to
|
||||||
|
// the enclosing function's arugments.
|
||||||
|
Status LiftVariables();
|
||||||
|
void LiftVariable(mlir::TF::VarHandleOp op);
|
||||||
|
|
||||||
|
// Read all variables from the SavedModel through session, and create
|
||||||
|
// GlobalTensorOp for these variables.
|
||||||
|
Status ReadVariablesFromSession(
|
||||||
|
const llvm::SmallVectorImpl<mlir::TF::VarHandleOp>& ops);
|
||||||
|
|
||||||
|
GraphImportConfig::InputArrays ParseInputArrays(
|
||||||
|
const tensorflow::protobuf::Map<std::string, TensorInfo>& inputs);
|
||||||
|
|
||||||
|
std::vector<std::string> ParseOutputArrays(
|
||||||
|
const tensorflow::protobuf::Map<std::string, TensorInfo>& outputs);
|
||||||
|
|
||||||
|
const SavedModelBundle& bundle_;
|
||||||
|
mlir::OwningModuleRef module_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert the SavedModel to TF Executor Dialect. It creates a MLIR function
|
||||||
|
// for each signature.
|
||||||
|
StatusOr<mlir::OwningModuleRef> SavedModelV1Importer::ConvertSignatures() {
|
||||||
|
const auto& signatures = bundle_.GetSignatures();
|
||||||
|
const auto& graphdef = bundle_.meta_graph_def.graph_def();
|
||||||
|
|
||||||
|
FunctionLibraryDefinition flib_def(OpRegistry::Global(), graphdef.library());
|
||||||
|
|
||||||
|
// debug_info might not be loaded with loader_lite.
|
||||||
|
GraphDebugInfo debug_info;
|
||||||
|
if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info;
|
||||||
|
|
||||||
|
for (const auto& key_and_signature_def : signatures) {
|
||||||
|
const auto& func_name = key_and_signature_def.first;
|
||||||
|
const auto& signature_def = key_and_signature_def.second;
|
||||||
|
GraphImportConfig specs;
|
||||||
|
specs.inputs = ParseInputArrays(signature_def.inputs());
|
||||||
|
specs.outputs = ParseOutputArrays(signature_def.outputs());
|
||||||
|
|
||||||
|
// Remove unused nodes and create a sub graphdef.
|
||||||
|
GraphDef sub_graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph(
|
||||||
|
graphdef, &sub_graph_def,
|
||||||
|
/* terminal_nodes = */ {specs.outputs.begin(), specs.outputs.end()}));
|
||||||
|
|
||||||
|
auto status_or_sub_module = ConvertSignature(
|
||||||
|
specs, func_name, signature_def, sub_graph_def, debug_info, flib_def);
|
||||||
|
if (!status_or_sub_module.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to convert SignatureDef for " << func_name << ": "
|
||||||
|
<< status_or_sub_module.status();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& sub_module = status_or_sub_module.ValueOrDie();
|
||||||
|
|
||||||
|
// Move the converted functions to top level MLIR module.
|
||||||
|
auto* block = module_->getBody();
|
||||||
|
auto* sub_block = sub_module->getBody();
|
||||||
|
block->getOperations().splice(
|
||||||
|
mlir::Block::iterator(block->getTerminator()),
|
||||||
|
sub_block->getOperations(), sub_block->begin(),
|
||||||
|
mlir::Block::iterator(sub_block->getTerminator()));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(LiftVariables());
|
||||||
|
|
||||||
|
return std::move(module_);
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<mlir::OwningModuleRef> SavedModelV1Importer::ConvertSignature(
|
||||||
|
const GraphImportConfig& specs, llvm::StringRef func_name,
|
||||||
|
const SignatureDef& signature_def, const GraphDef& sub_graph_def,
|
||||||
|
const GraphDebugInfo& debug_info,
|
||||||
|
const FunctionLibraryDefinition& flib_def) {
|
||||||
|
// Convert this sub graphdef to sub graph
|
||||||
|
GraphConstructorOptions options;
|
||||||
|
options.allow_internal_ops = true;
|
||||||
|
options.add_default_attributes = true;
|
||||||
|
Graph sub_graph(OpRegistry::Global());
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ConvertGraphDefToGraph(options, sub_graph_def, &sub_graph));
|
||||||
|
|
||||||
|
// Convert the sub graphdef to a MLIR function.
|
||||||
|
return GraphDefImporter::Convert(module_->getContext(), sub_graph, debug_info,
|
||||||
|
flib_def, specs, func_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create GlobalTensorOp for each variable and move each VarHandle op to
|
||||||
|
// the enclosing function's arugments.
|
||||||
|
Status SavedModelV1Importer::LiftVariables() {
|
||||||
|
llvm::SmallVector<mlir::TF::VarHandleOp, 4> ops;
|
||||||
|
|
||||||
|
bool contains_ref_variable = false;
|
||||||
|
|
||||||
|
module_->walk([&ops, &contains_ref_variable](mlir::Operation* op) {
|
||||||
|
if (auto var_handle_op = llvm::dyn_cast<mlir::TF::VarHandleOp>(op))
|
||||||
|
ops.push_back(var_handle_op);
|
||||||
|
else if (op->getName().getStringRef() == "tf.VariableV2")
|
||||||
|
contains_ref_variable = true;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (contains_ref_variable)
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Ref variable created by VariableV2 is not supported.");
|
||||||
|
|
||||||
|
if (ops.empty()) return Status::OK();
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(ReadVariablesFromSession(ops));
|
||||||
|
|
||||||
|
for (auto op : ops) LiftVariable(op);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move the result of the VarHandleOp to the enclosing function's arugment list
|
||||||
|
// and erase this VarHandleOp.
|
||||||
|
void SavedModelV1Importer::LiftVariable(mlir::TF::VarHandleOp op) {
|
||||||
|
mlir::OpBuilder builder(&module_->getBodyRegion());
|
||||||
|
|
||||||
|
auto func_op = op.getParentOfType<mlir::FuncOp>();
|
||||||
|
builder.setInsertionPoint(func_op);
|
||||||
|
|
||||||
|
auto func_type = func_op.getType();
|
||||||
|
|
||||||
|
// Create the new function type by adding variable type to the arguments.
|
||||||
|
llvm::SmallVector<mlir::Type, 4> new_input_types(
|
||||||
|
func_type.getInputs().begin(), func_type.getInputs().end());
|
||||||
|
new_input_types.push_back(op.resource()->getType());
|
||||||
|
auto new_func_type =
|
||||||
|
builder.getFunctionType(new_input_types, func_type.getResults());
|
||||||
|
|
||||||
|
auto new_func_op = builder.create<mlir::FuncOp>(
|
||||||
|
func_op.getLoc(), func_op.getName(), new_func_type,
|
||||||
|
llvm::ArrayRef<mlir::NamedAttribute>());
|
||||||
|
|
||||||
|
// Bind the argument to the corresponding global tensor op.
|
||||||
|
new_func_op.setArgAttr(new_func_op.getNumArguments() - 1,
|
||||||
|
"tf_saved_model.bound_input",
|
||||||
|
builder.getSymbolRefAttr(op.shared_name()));
|
||||||
|
|
||||||
|
// Replace the function body and update its signature.
|
||||||
|
auto& new_region = new_func_op.getBody();
|
||||||
|
new_region.getBlocks().splice(new_region.end(),
|
||||||
|
func_op.getBody().getBlocks());
|
||||||
|
|
||||||
|
func_op.getOperation()->erase();
|
||||||
|
|
||||||
|
auto& new_block = new_region.front();
|
||||||
|
auto new_value = new_block.addArgument(op.resource()->getType());
|
||||||
|
|
||||||
|
op.getOperation()->replaceAllUsesWith(llvm::ArrayRef<mlir::Value>(new_value));
|
||||||
|
|
||||||
|
op.getOperation()->erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read all variables from the SavedModel through session, and create
|
||||||
|
// GlobalTensorOp for these variables.
|
||||||
|
Status SavedModelV1Importer::ReadVariablesFromSession(
|
||||||
|
const llvm::SmallVectorImpl<mlir::TF::VarHandleOp>& ops) {
|
||||||
|
mlir::OpBuilder builder(&module_->getBodyRegion());
|
||||||
|
|
||||||
|
// Find all variables and their corresponding read ops.
|
||||||
|
|
||||||
|
llvm::MapVector<llvm::StringRef, mlir::TF::VarHandleOp>
|
||||||
|
variable_names_and_ops;
|
||||||
|
for (auto op : ops) {
|
||||||
|
variable_names_and_ops[op.shared_name()] = op;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read all resource variables from the session.
|
||||||
|
|
||||||
|
std::vector<std::string> variable_names;
|
||||||
|
variable_names.reserve(variable_names_and_ops.size());
|
||||||
|
for (const auto& name_and_location : variable_names_and_ops)
|
||||||
|
variable_names.push_back(name_and_location.first);
|
||||||
|
|
||||||
|
std::vector<Tensor> resource_tensors;
|
||||||
|
TF_RETURN_IF_ERROR(bundle_.GetSession()->Run(
|
||||||
|
/*inputs=*/{}, variable_names,
|
||||||
|
/*target_node_names=*/{}, &resource_tensors));
|
||||||
|
|
||||||
|
const DeviceMgr* device_manager;
|
||||||
|
TF_RETURN_IF_ERROR(bundle_.GetSession()->LocalDeviceManager(&device_manager));
|
||||||
|
|
||||||
|
// Read all underlying tensors of the variables from the session.
|
||||||
|
std::vector<Tensor> tensors;
|
||||||
|
tensors.reserve(resource_tensors.size());
|
||||||
|
for (const auto& resource_tensor : resource_tensors) {
|
||||||
|
const auto& resource_handle = resource_tensor.scalar<ResourceHandle>()();
|
||||||
|
|
||||||
|
Device* device;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
device_manager->LookupDevice(resource_handle.device(), &device));
|
||||||
|
|
||||||
|
Var* var_ptr;
|
||||||
|
TF_RETURN_IF_ERROR(device->resource_manager()->Lookup(
|
||||||
|
resource_handle.container(), resource_handle.name(), &var_ptr));
|
||||||
|
core::RefCountPtr<Var> var(var_ptr);
|
||||||
|
|
||||||
|
// The variable tensor is already loaded into corresponding device's
|
||||||
|
// resource manager when we load the saved model using LoadSavedModel().
|
||||||
|
// Here we just read its value.
|
||||||
|
mutex_lock ml(*var->mu());
|
||||||
|
tensors.push_back(*var->tensor());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& iter : llvm::zip(variable_names_and_ops, tensors)) {
|
||||||
|
const auto& name = std::get<0>(iter).first;
|
||||||
|
auto location = std::get<0>(iter).second.getLoc();
|
||||||
|
const auto& tensor = std::get<1>(iter);
|
||||||
|
|
||||||
|
// Create tensor attribute for this variable.
|
||||||
|
TF_ASSIGN_OR_RETURN(auto tensor_attr, ConvertTensor(tensor, &builder));
|
||||||
|
|
||||||
|
builder.create<mlir::tf_saved_model::GlobalTensorOp>(
|
||||||
|
location, builder.getStringAttr(name), tensor_attr,
|
||||||
|
mlir::TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphImportConfig::InputArrays SavedModelV1Importer::ParseInputArrays(
|
||||||
|
const tensorflow::protobuf::Map<std::string, TensorInfo>& inputs) {
|
||||||
|
GraphImportConfig::InputArrays results;
|
||||||
|
for (const auto& iter : inputs) {
|
||||||
|
const auto& tensor_info = iter.second;
|
||||||
|
|
||||||
|
// Only dense tensor is supported.
|
||||||
|
DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName);
|
||||||
|
|
||||||
|
ArrayInfo array_info;
|
||||||
|
array_info.imported_dtype = tensor_info.dtype();
|
||||||
|
array_info.shape = tensor_info.tensor_shape();
|
||||||
|
|
||||||
|
std::vector<std::string> node_names =
|
||||||
|
absl::StrSplit(tensor_info.name(), ':');
|
||||||
|
|
||||||
|
results.insert(std::pair<std::string, ArrayInfo>(node_names.at(0),
|
||||||
|
std::move(array_info)));
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> SavedModelV1Importer::ParseOutputArrays(
|
||||||
|
const tensorflow::protobuf::Map<std::string, TensorInfo>& outputs) {
|
||||||
|
std::vector<std::string> results;
|
||||||
|
for (const auto& iter : outputs) {
|
||||||
|
const auto& tensor_info = iter.second;
|
||||||
|
|
||||||
|
std::vector<std::string> node_names =
|
||||||
|
absl::StrSplit(tensor_info.name(), ':');
|
||||||
|
results.push_back(node_names.at(0));
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) {
|
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) {
|
||||||
@ -2806,7 +3095,8 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
|
|||||||
UpgradeLegacyGraph(const_cast<Graph*>(&graph),
|
UpgradeLegacyGraph(const_cast<Graph*>(&graph),
|
||||||
const_cast<FunctionLibraryDefinition*>(&flib_def)));
|
const_cast<FunctionLibraryDefinition*>(&flib_def)));
|
||||||
}
|
}
|
||||||
return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs);
|
return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
|
||||||
|
/* func_name = */ "main");
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||||
@ -2816,6 +3106,11 @@ StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
|||||||
add_default_attributes);
|
add_default_attributes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<mlir::OwningModuleRef> ConvertSavedModelV1ToMlir(
|
||||||
|
const SavedModelBundle& saved_model, mlir::MLIRContext* context) {
|
||||||
|
return SavedModelV1Importer::Convert(saved_model, context);
|
||||||
|
}
|
||||||
|
|
||||||
std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
|
std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) {
|
||||||
std::string txt_module;
|
std::string txt_module;
|
||||||
{
|
{
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||||
#include "mlir/IR/Module.h" // TF:llvm-project
|
#include "mlir/IR/Module.h" // TF:llvm-project
|
||||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||||
|
#include "tensorflow/cc/saved_model/loader.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
@ -50,6 +51,12 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
|||||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||||
absl::Span<std::string> exported_names, bool add_default_attributes = true);
|
absl::Span<std::string> exported_names, bool add_default_attributes = true);
|
||||||
|
|
||||||
|
// Given a V1 SavedModel, returns a MLIR module containing the functions,
|
||||||
|
// expressed with tf_executor dialect.
|
||||||
|
stream_executor::port::StatusOr<mlir::OwningModuleRef>
|
||||||
|
ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model,
|
||||||
|
mlir::MLIRContext* context);
|
||||||
|
|
||||||
// Serialize a MLIR module to a string.
|
// Serialize a MLIR module to a string.
|
||||||
std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false);
|
std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false);
|
||||||
|
|
||||||
|
@ -130,6 +130,27 @@ mlir::OwningModuleRef SavedModelToMlirImport(
|
|||||||
return module_or.ConsumeValueOrDie();
|
return module_or.ConsumeValueOrDie();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mlir::OwningModuleRef SavedModelV1ToMlirImport(
|
||||||
|
absl::string_view saved_model_dir,
|
||||||
|
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context) {
|
||||||
|
tensorflow::SavedModelBundle bundle;
|
||||||
|
auto load_status = tensorflow::LoadSavedModel(
|
||||||
|
/* session_options = */ {}, /* run_options = */ {},
|
||||||
|
std::string(saved_model_dir), tags, &bundle);
|
||||||
|
if (!load_status.ok()) {
|
||||||
|
LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir
|
||||||
|
<< "': " << load_status;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto module_or = ConvertSavedModelV1ToMlir(bundle, context);
|
||||||
|
if (!module_or.status().ok()) {
|
||||||
|
LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return module_or.ConsumeValueOrDie();
|
||||||
|
}
|
||||||
|
|
||||||
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction(
|
||||||
llvm::StringRef input, absl::string_view debug_info_file,
|
llvm::StringRef input, absl::string_view debug_info_file,
|
||||||
absl::string_view input_arrays, absl::string_view input_dtypes,
|
absl::string_view input_arrays, absl::string_view input_dtypes,
|
||||||
|
@ -54,6 +54,14 @@ mlir::OwningModuleRef SavedModelToMlirImport(
|
|||||||
absl::string_view saved_model_dir,
|
absl::string_view saved_model_dir,
|
||||||
const std::unordered_set<std::string>& tags,
|
const std::unordered_set<std::string>& tags,
|
||||||
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
|
absl::Span<std::string> exported_names, mlir::MLIRContext* context);
|
||||||
|
|
||||||
|
// Converts a TensorFlow V1 SavedModel stored in the directory with the given
|
||||||
|
// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
|
||||||
|
// given MLIR `context`.
|
||||||
|
mlir::OwningModuleRef SavedModelV1ToMlirImport(
|
||||||
|
absl::string_view saved_model_dir,
|
||||||
|
const std::unordered_set<std::string>& tags, mlir::MLIRContext* context);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_
|
||||||
|
@ -54,6 +54,12 @@ static llvm::cl::opt<bool> import_saved_model(
|
|||||||
llvm::cl::desc("Import a saved model to its MLIR representation"),
|
llvm::cl::desc("Import a saved model to its MLIR representation"),
|
||||||
llvm::cl::value_desc("dir"));
|
llvm::cl::value_desc("dir"));
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE
|
||||||
|
static llvm::cl::opt<bool> import_saved_model_v1(
|
||||||
|
"savedmodel-v1-to-mlir",
|
||||||
|
llvm::cl::desc("Import a saved model V1 to its MLIR representation"),
|
||||||
|
llvm::cl::value_desc("dir"));
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
static llvm::cl::opt<std::string> saved_model_tags(
|
static llvm::cl::opt<std::string> saved_model_tags(
|
||||||
"tf-savedmodel-tags",
|
"tf-savedmodel-tags",
|
||||||
@ -77,10 +83,11 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n");
|
llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n");
|
||||||
|
|
||||||
if (!import_saved_model && !requested_translation) {
|
if (!import_saved_model && !import_saved_model_v1 && !requested_translation) {
|
||||||
llvm::errs() << "error: need to specify one translation to perform\n";
|
llvm::errs() << "error: need to specify one translation to perform\n";
|
||||||
return 1;
|
return 1;
|
||||||
} else if (import_saved_model && requested_translation) {
|
} else if (import_saved_model && import_saved_model_v1 &&
|
||||||
|
requested_translation) {
|
||||||
llvm::errs()
|
llvm::errs()
|
||||||
<< "error: cannot specify more than one translation to perform\n";
|
<< "error: cannot specify more than one translation to perform\n";
|
||||||
return 1;
|
return 1;
|
||||||
@ -105,6 +112,16 @@ int main(int argc, char** argv) {
|
|||||||
&context);
|
&context);
|
||||||
if (!module) return 1;
|
if (!module) return 1;
|
||||||
|
|
||||||
|
module->print(output->os());
|
||||||
|
} else if (import_saved_model_v1) {
|
||||||
|
std::unordered_set<std::string> tags =
|
||||||
|
absl::StrSplit(saved_model_tags, ',');
|
||||||
|
mlir::MLIRContext context;
|
||||||
|
|
||||||
|
auto module =
|
||||||
|
tensorflow::SavedModelV1ToMlirImport(input_filename, tags, &context);
|
||||||
|
if (!module) return 1;
|
||||||
|
|
||||||
module->print(output->os());
|
module->print(output->os());
|
||||||
} else {
|
} else {
|
||||||
auto input = mlir::openInputFile(input_filename, &error_message);
|
auto input = mlir::openInputFile(input_filename, &error_message);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user