From d25dd807485c6e57a200fcf9622a974f24ad1a2f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 3 Jan 2020 13:46:45 -0800 Subject: [PATCH] Initial version of SavedModel V1 Importer that converts a V1 SavedModel to a MLIR Module that contains functions specified by signature defs. PiperOrigin-RevId: 288042933 Change-Id: I5dfde397eb8635020025aa1dc6fee690e4b45ae3 --- tensorflow/compiler/mlir/python/mlir.i | 53 +++ tensorflow/compiler/mlir/tensorflow/BUILD | 3 + .../tensorflow/tests/tf_saved_model/BUILD | 14 +- .../tests/tf_saved_model/basic_v1.py | 64 ++++ .../tests/tf_saved_model/build_defs.bzl | 1 + .../tests/tf_saved_model/common_v1.py | 93 ++++++ .../tf_saved_model/shared_variable_v1.py | 64 ++++ .../mlir/tensorflow/translate/import_model.cc | 305 +++++++++++++++++- .../mlir/tensorflow/translate/import_model.h | 7 + .../tensorflow/translate/tf_mlir_translate.cc | 21 ++ .../tensorflow/translate/tf_mlir_translate.h | 8 + .../compiler/mlir/tf_mlir_translate_main.cc | 21 +- 12 files changed, 646 insertions(+), 8 deletions(-) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py diff --git a/tensorflow/compiler/mlir/python/mlir.i b/tensorflow/compiler/mlir/python/mlir.i index 2ecea47b3d3..b1d53288204 100644 --- a/tensorflow/compiler/mlir/python/mlir.i +++ b/tensorflow/compiler/mlir/python/mlir.i @@ -108,6 +108,45 @@ string ExperimentalConvertSavedModelToMlir( return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info); } +// Load a SavedModel V1 and return a textual MLIR string corresponding to it. +// +// Args: +// saved_model_path: File path from which to load the SavedModel. +// tags: Tags to identify MetaGraphDef that need to be loaded. +// +// Returns: +// A string of textual MLIR representing the raw imported SavedModel. +string ExperimentalConvertSavedModelV1ToMlir( + const string &saved_model_path, + const string &tags, + bool show_debug_info, + TF_Status* status) { + // Load the saved model into a SavedModelBundle. + + std::unordered_set tag_set + = absl::StrSplit(tags, ',', absl::SkipEmpty()); + + tensorflow::SavedModelBundle bundle; + auto load_status = tensorflow::LoadSavedModel( + {}, {}, + saved_model_path, tag_set, &bundle); + if (!load_status.ok()) { + Set_TF_Status_from_Status(status, load_status); + return "// error"; + } + + // Convert the SavedModelBundle to an MLIR module. + + mlir::MLIRContext context; + auto module_or = ConvertSavedModelV1ToMlir(bundle, &context); + if (!module_or.status().ok()) { + Set_TF_Status_from_Status(status, module_or.status()); + return "// error"; + } + + return MlirModuleToString(*module_or.ConsumeValueOrDie(), show_debug_info); +} + string ExperimentalRunPassPipeline( const string &mlir_txt, @@ -154,6 +193,7 @@ string ExperimentalRunPassPipeline( %unignore tensorflow::swig; %unignore tensorflow::swig::ImportGraphDef; %unignore tensorflow::swig::ExperimentalConvertSavedModelToMlir; +%unignore tensorflow::swig::ExperimentalConvertSavedModelV1ToMlir; %unignore tensorflow::swig::ExperimentalRunPassPipeline; // Wrap this function @@ -167,6 +207,11 @@ static string ExperimentalConvertSavedModelToMlir( const string &exported_names, bool show_debug_info, TF_Status* status); +static string ExperimentalConvertSavedModelV1ToMlir( + const string &saved_model_path, + const string &tags, + bool show_debug_info, + TF_Status* status); static string ExperimentalRunPassPipeline( const string &mlir_txt, const string &pass_pipeline, @@ -188,6 +233,14 @@ def experimental_convert_saved_model_to_mlir(saved_model_path, show_debug_info ).decode('utf-8'); +def experimental_convert_saved_model_v1_to_mlir(saved_model_path, + tags, show_debug_info): + return ExperimentalConvertSavedModelV1ToMlir( + str(saved_model_path).encode('utf-8'), + str(tags).encode('utf-8'), + show_debug_info + ).decode('utf-8'); + def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info): return ExperimentalRunPassPipeline( mlir_txt.encode('utf-8'), diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index a1710bf1f4a..2888997c7b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -348,15 +348,18 @@ cc_library( ":tensorflow", ":tensorflow_passes", "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/compiler/jit:shape_inference_helpers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/tf2xla:functionalize_control_flow", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/utils:transitive_fanin", "//tensorflow/core/platform:types", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD index abad9b7e916..93ee05d478e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD @@ -13,6 +13,15 @@ py_library( ], ) +py_library( + name = "common_v1", + srcs = ["common_v1.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + filegroup( name = "test_utilities", testonly = True, @@ -24,7 +33,10 @@ filegroup( # Drop trailing ".py" from all test file names. all_test_basenames = [py[:-3] for py in glob( ["*.py"], - exclude = ["common.py"], + exclude = [ + "common.py", + "common_v1.py", + ], )] # Instantiate all the tests. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py new file mode 100644 index 00000000000..8fb8b4e6e2d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/basic_v1.py @@ -0,0 +1,64 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/basic_v1 | FileCheck %s + +# pylint: disable=missing-docstring,line-too-long +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v1 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 + +# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "y", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> () +# CHECK: func @basic([[ARG0:%.*]]: tensor<3x1xf32>, +# CHECK-SAME: [[ARG1:%.*]]: tensor>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32> +# CHECK-NEXT: [[R0:%.*]] = "tf.ReadVariableOp"([[ARG1]]) {{{.*}}} : (tensor>>) -> tensor<1x3xf32> +# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> +# CHECK-NEXT: return [[R1]] : tensor<3x3xf32> + + +def Test(): + + # Default TF1.x uses reference variables that are not supported by SavedModel + # v1 Importer. To use SavedModel V1 Importer, resource variables should be + # enabled. + tf.compat.v1.enable_resource_variables() + + tf.compat.v1.disable_eager_execution() + + x = tf.constant([[1.0], [1.0], [1.0]]) + y = tf.compat.v1.get_variable( + name='y', + shape=(1, 3), + initializer=tf.random_normal_initializer(), + trainable=True) + r = tf.matmul(x, y) + + tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x) + tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r) + + return { + 'basic': + (tf.compat.v1.saved_model.signature_def_utils.build_signature_def( + inputs={'x': tensor_info_x}, + outputs={'r': tensor_info_r}, + method_name=tf.saved_model.PREDICT_METHOD_NAME)) + } + + +if __name__ == '__main__': + common_v1.do_test(Test()) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl index 4fc49613abc..0e83900d98c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl @@ -11,6 +11,7 @@ def tf_saved_model_test(name, data): srcs = [name + ".py"], deps = [ "//tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model:common", + "//tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model:common_v1", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py new file mode 100644 index 00000000000..35858d2b38a --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py @@ -0,0 +1,93 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Serves as a common "main" function for all the SavedModel tests. + +There is a fair amount of setup needed to initialize tensorflow and get it +into a proper TF2 execution mode. This hides that boilerplate. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +from absl import app +from absl import flags +from absl import logging +import tensorflow.compat.v1 as tf + +from tensorflow.python import pywrap_tensorflow + +# Use /tmp to make debugging the tests easier (see README.md) +flags.DEFINE_string('save_model_path', '', 'Path to save the model to.') +FLAGS = flags.FLAGS + + +# This function needs to take a "create_module_fn", as opposed to just the +# module itself, because the creation of the module has to be delayed until +# after absl and tensorflow have run various initialization steps. +def do_test(signature_def_map, show_debug_info=False): + """Runs test. + + 1. Performs absl and tf "main"-like initialization that must run before almost + anything else. + 2. Converts signature_def_map to SavedModel V1 + 3. Converts SavedModel V1 to MLIR + 4. Prints the textual MLIR to stdout (it is expected that the caller will have + FileCheck checks in its file to check this output). + + This is only for use by the MLIR SavedModel importer tests. + + Args: + signature_def_map: A map from string key to signature_def. The key will be + used as function name in the resulting MLIR. + show_debug_info: If true, shows debug locations in the resulting MLIR. + """ + + # Make LOG(ERROR) in C++ code show up on the console. + # All `Status` passed around in the C++ API seem to eventually go into + # `LOG(ERROR)`, so this makes them print out by default. + logging.set_stderrthreshold('error') + + def app_main(argv): + """Function passed to absl.app.run.""" + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + if FLAGS.save_model_path: + save_model_path = FLAGS.save_model_path + else: + save_model_path = tempfile.mktemp(suffix='.saved_model') + + sess = tf.Session() + sess.run(tf.initializers.global_variables()) + builder = tf.saved_model.builder.SavedModelBuilder(save_model_path) + builder.add_meta_graph_and_variables( + sess, [tf.saved_model.tag_constants.SERVING], + signature_def_map, + strip_default_attrs=True) + builder.save() + + logging.info('Saved model to: %s', save_model_path) + mlir = pywrap_tensorflow.experimental_convert_saved_model_v1_to_mlir( + save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]), + show_debug_info) + # We don't strictly need this, but it serves as a handy sanity check + # for that API, which is otherwise a bit annoying to test. + # The canonicalization shouldn't affect these tests in any way. + mlir = pywrap_tensorflow.experimental_run_pass_pipeline( + mlir, 'tf-standard-pipeline', show_debug_info) + print(mlir) + + app.run(app_main) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py new file mode 100644 index 00000000000..6ba51c2a325 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/shared_variable_v1.py @@ -0,0 +1,64 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/shared_variable_v1 | FileCheck %s + +# pylint: disable=missing-docstring,line-too-long +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v1 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 + +# CHECK: "tf_saved_model.global_tensor"() {is_mutable, sym_name = "y", type = tensor<1x3xf32>, value = {{.*}} : tensor<1x3xf32>} : () -> () +# CHECK: func {{@.*}}([[ARG0:%.*]]: tensor<3x1xf32>, +# CHECK-SAME: [[ARG1:%.*]]: tensor>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32> + +# CHECK: func {{@.*}}([[ARG2:%.*]]: tensor<3x1xf32>, +# CHECK-SAME: [[ARG3:%.*]]: tensor>> {tf_saved_model.bound_input = @y}) -> tensor<3x3xf32> + + +def Test(): + + # Default TF1.x uses reference variables that are not supported by SavedModel + # v1 Importer. To use SavedModel V1 Importer, resource variables should be + # enabled. + tf.enable_resource_variables() + + tf.compat.v1.disable_eager_execution() + + x = tf.constant([[1.0], [1.0], [1.0]]) + y = tf.get_variable( + name='y', + shape=(1, 3), + initializer=tf.random_normal_initializer(), + trainable=True) + r = tf.matmul(x, y) + + tensor_info_x = tf.saved_model.utils.build_tensor_info(x) + tensor_info_r = tf.saved_model.utils.build_tensor_info(r) + + signature_def = tf.saved_model.signature_def_utils.build_signature_def( + inputs={'x': tensor_info_x}, + outputs={'r': tensor_info_r}, + method_name=tf.saved_model.PREDICT_METHOD_NAME) + + # Create two signatures that share the same variable. + return {'basic': signature_def, 'basic_2': signature_def} + + +if __name__ == '__main__': + common_v1.do_test(Test()) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index cf96e9b3789..0f258495f47 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -35,6 +35,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" @@ -71,6 +72,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" @@ -81,6 +83,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/utils/transitive_fanin.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" @@ -1734,8 +1737,8 @@ class GraphDefImporter : public ImporterBase { static StatusOr Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def, - const GraphImportConfig& specs); + const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, + llvm::StringRef func_name); private: explicit GraphDefImporter( @@ -1773,7 +1776,7 @@ class GraphDefImporter : public ImporterBase { StatusOr GraphDefImporter::Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, - const GraphImportConfig& specs) { + const GraphImportConfig& specs, llvm::StringRef func_name) { mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); std::unordered_map tf_name_to_mlir_name; @@ -1861,7 +1864,7 @@ StatusOr GraphDefImporter::Convert( {producer, min_consumer, bad_consumers}))); TF_RETURN_IF_ERROR(importer.ImporterBase::Convert( - "main", func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs, + func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs, resource_arg_unique_ids)); return module; } @@ -2771,6 +2774,292 @@ StatusOr SavedModelImporter::Convert( return module; } +// A helper class to import a TensorFlow model expressed in SavedModel V1 into +// an MLIR Module. +class SavedModelV1Importer { + public: + // Main entry point: converts all functions (specified by SignatureDefs) in + // the given meta graph to an MLIR Module. + static StatusOr Convert(const SavedModelBundle& bundle, + mlir::MLIRContext* context) { + SavedModelV1Importer importer(bundle, context); + + return importer.ConvertSignatures(); + } + + private: + SavedModelV1Importer(const SavedModelBundle& bundle, + mlir::MLIRContext* context) + : bundle_(bundle), + module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {} + + // Convert the SavedModel to TF Executor Dialect. It creates a MLIR function + // for each signature. + StatusOr ConvertSignatures(); + StatusOr ConvertSignature( + const GraphImportConfig& specs, llvm::StringRef func_name, + const SignatureDef& signature_def, const GraphDef& sub_graph_def, + const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def); + + // Create GlobalTensorOp for each variable and move each VarHandle op to + // the enclosing function's arugments. + Status LiftVariables(); + void LiftVariable(mlir::TF::VarHandleOp op); + + // Read all variables from the SavedModel through session, and create + // GlobalTensorOp for these variables. + Status ReadVariablesFromSession( + const llvm::SmallVectorImpl& ops); + + GraphImportConfig::InputArrays ParseInputArrays( + const tensorflow::protobuf::Map& inputs); + + std::vector ParseOutputArrays( + const tensorflow::protobuf::Map& outputs); + + const SavedModelBundle& bundle_; + mlir::OwningModuleRef module_; +}; + +// Convert the SavedModel to TF Executor Dialect. It creates a MLIR function +// for each signature. +StatusOr SavedModelV1Importer::ConvertSignatures() { + const auto& signatures = bundle_.GetSignatures(); + const auto& graphdef = bundle_.meta_graph_def.graph_def(); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), graphdef.library()); + + // debug_info might not be loaded with loader_lite. + GraphDebugInfo debug_info; + if (bundle_.debug_info != nullptr) debug_info = *bundle_.debug_info; + + for (const auto& key_and_signature_def : signatures) { + const auto& func_name = key_and_signature_def.first; + const auto& signature_def = key_and_signature_def.second; + GraphImportConfig specs; + specs.inputs = ParseInputArrays(signature_def.inputs()); + specs.outputs = ParseOutputArrays(signature_def.outputs()); + + // Remove unused nodes and create a sub graphdef. + GraphDef sub_graph_def; + TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph( + graphdef, &sub_graph_def, + /* terminal_nodes = */ {specs.outputs.begin(), specs.outputs.end()})); + + auto status_or_sub_module = ConvertSignature( + specs, func_name, signature_def, sub_graph_def, debug_info, flib_def); + if (!status_or_sub_module.ok()) { + LOG(ERROR) << "Failed to convert SignatureDef for " << func_name << ": " + << status_or_sub_module.status(); + continue; + } + + auto& sub_module = status_or_sub_module.ValueOrDie(); + + // Move the converted functions to top level MLIR module. + auto* block = module_->getBody(); + auto* sub_block = sub_module->getBody(); + block->getOperations().splice( + mlir::Block::iterator(block->getTerminator()), + sub_block->getOperations(), sub_block->begin(), + mlir::Block::iterator(sub_block->getTerminator())); + } + + TF_RETURN_IF_ERROR(LiftVariables()); + + return std::move(module_); +} + +StatusOr SavedModelV1Importer::ConvertSignature( + const GraphImportConfig& specs, llvm::StringRef func_name, + const SignatureDef& signature_def, const GraphDef& sub_graph_def, + const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def) { + // Convert this sub graphdef to sub graph + GraphConstructorOptions options; + options.allow_internal_ops = true; + options.add_default_attributes = true; + Graph sub_graph(OpRegistry::Global()); + + TF_RETURN_IF_ERROR( + ConvertGraphDefToGraph(options, sub_graph_def, &sub_graph)); + + // Convert the sub graphdef to a MLIR function. + return GraphDefImporter::Convert(module_->getContext(), sub_graph, debug_info, + flib_def, specs, func_name); +} + +// Create GlobalTensorOp for each variable and move each VarHandle op to +// the enclosing function's arugments. +Status SavedModelV1Importer::LiftVariables() { + llvm::SmallVector ops; + + bool contains_ref_variable = false; + + module_->walk([&ops, &contains_ref_variable](mlir::Operation* op) { + if (auto var_handle_op = llvm::dyn_cast(op)) + ops.push_back(var_handle_op); + else if (op->getName().getStringRef() == "tf.VariableV2") + contains_ref_variable = true; + }); + + if (contains_ref_variable) + return errors::InvalidArgument( + "Ref variable created by VariableV2 is not supported."); + + if (ops.empty()) return Status::OK(); + + TF_RETURN_IF_ERROR(ReadVariablesFromSession(ops)); + + for (auto op : ops) LiftVariable(op); + + return Status::OK(); +} + +// Move the result of the VarHandleOp to the enclosing function's arugment list +// and erase this VarHandleOp. +void SavedModelV1Importer::LiftVariable(mlir::TF::VarHandleOp op) { + mlir::OpBuilder builder(&module_->getBodyRegion()); + + auto func_op = op.getParentOfType(); + builder.setInsertionPoint(func_op); + + auto func_type = func_op.getType(); + + // Create the new function type by adding variable type to the arguments. + llvm::SmallVector new_input_types( + func_type.getInputs().begin(), func_type.getInputs().end()); + new_input_types.push_back(op.resource()->getType()); + auto new_func_type = + builder.getFunctionType(new_input_types, func_type.getResults()); + + auto new_func_op = builder.create( + func_op.getLoc(), func_op.getName(), new_func_type, + llvm::ArrayRef()); + + // Bind the argument to the corresponding global tensor op. + new_func_op.setArgAttr(new_func_op.getNumArguments() - 1, + "tf_saved_model.bound_input", + builder.getSymbolRefAttr(op.shared_name())); + + // Replace the function body and update its signature. + auto& new_region = new_func_op.getBody(); + new_region.getBlocks().splice(new_region.end(), + func_op.getBody().getBlocks()); + + func_op.getOperation()->erase(); + + auto& new_block = new_region.front(); + auto new_value = new_block.addArgument(op.resource()->getType()); + + op.getOperation()->replaceAllUsesWith(llvm::ArrayRef(new_value)); + + op.getOperation()->erase(); +} + +// Read all variables from the SavedModel through session, and create +// GlobalTensorOp for these variables. +Status SavedModelV1Importer::ReadVariablesFromSession( + const llvm::SmallVectorImpl& ops) { + mlir::OpBuilder builder(&module_->getBodyRegion()); + + // Find all variables and their corresponding read ops. + + llvm::MapVector + variable_names_and_ops; + for (auto op : ops) { + variable_names_and_ops[op.shared_name()] = op; + } + + // Read all resource variables from the session. + + std::vector variable_names; + variable_names.reserve(variable_names_and_ops.size()); + for (const auto& name_and_location : variable_names_and_ops) + variable_names.push_back(name_and_location.first); + + std::vector resource_tensors; + TF_RETURN_IF_ERROR(bundle_.GetSession()->Run( + /*inputs=*/{}, variable_names, + /*target_node_names=*/{}, &resource_tensors)); + + const DeviceMgr* device_manager; + TF_RETURN_IF_ERROR(bundle_.GetSession()->LocalDeviceManager(&device_manager)); + + // Read all underlying tensors of the variables from the session. + std::vector tensors; + tensors.reserve(resource_tensors.size()); + for (const auto& resource_tensor : resource_tensors) { + const auto& resource_handle = resource_tensor.scalar()(); + + Device* device; + TF_RETURN_IF_ERROR( + device_manager->LookupDevice(resource_handle.device(), &device)); + + Var* var_ptr; + TF_RETURN_IF_ERROR(device->resource_manager()->Lookup( + resource_handle.container(), resource_handle.name(), &var_ptr)); + core::RefCountPtr var(var_ptr); + + // The variable tensor is already loaded into corresponding device's + // resource manager when we load the saved model using LoadSavedModel(). + // Here we just read its value. + mutex_lock ml(*var->mu()); + tensors.push_back(*var->tensor()); + } + + for (const auto& iter : llvm::zip(variable_names_and_ops, tensors)) { + const auto& name = std::get<0>(iter).first; + auto location = std::get<0>(iter).second.getLoc(); + const auto& tensor = std::get<1>(iter); + + // Create tensor attribute for this variable. + TF_ASSIGN_OR_RETURN(auto tensor_attr, ConvertTensor(tensor, &builder)); + + builder.create( + location, builder.getStringAttr(name), tensor_attr, + mlir::TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr()); + } + + return Status::OK(); +} + +GraphImportConfig::InputArrays SavedModelV1Importer::ParseInputArrays( + const tensorflow::protobuf::Map& inputs) { + GraphImportConfig::InputArrays results; + for (const auto& iter : inputs) { + const auto& tensor_info = iter.second; + + // Only dense tensor is supported. + DCHECK_EQ(tensor_info.encoding_case(), tensorflow::TensorInfo::kName); + + ArrayInfo array_info; + array_info.imported_dtype = tensor_info.dtype(); + array_info.shape = tensor_info.tensor_shape(); + + std::vector node_names = + absl::StrSplit(tensor_info.name(), ':'); + + results.insert(std::pair(node_names.at(0), + std::move(array_info))); + } + return results; +} + +std::vector SavedModelV1Importer::ParseOutputArrays( + const tensorflow::protobuf::Map& outputs) { + std::vector results; + for (const auto& iter : outputs) { + const auto& tensor_info = iter.second; + + std::vector node_names = + absl::StrSplit(tensor_info.name(), ':'); + results.push_back(node_names.at(0)); + } + return results; +} + } // namespace Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def) { @@ -2806,7 +3095,8 @@ StatusOr ConvertGraphToMlir( UpgradeLegacyGraph(const_cast(&graph), const_cast(&flib_def))); } - return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs); + return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs, + /* func_name = */ "main"); } StatusOr ConvertSavedModelToMlir( @@ -2816,6 +3106,11 @@ StatusOr ConvertSavedModelToMlir( add_default_attributes); } +StatusOr ConvertSavedModelV1ToMlir( + const SavedModelBundle& saved_model, mlir::MLIRContext* context) { + return SavedModelV1Importer::Convert(saved_model, context); +} + std::string MlirModuleToString(mlir::ModuleOp module, bool show_debug_info) { std::string txt_module; { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index 9f04d8aa782..efc316483fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Module.h" // TF:llvm-project #include "tensorflow/cc/saved_model/bundle_v2.h" +#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -50,6 +51,12 @@ stream_executor::port::StatusOr ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, bool add_default_attributes = true); +// Given a V1 SavedModel, returns a MLIR module containing the functions, +// expressed with tf_executor dialect. +stream_executor::port::StatusOr +ConvertSavedModelV1ToMlir(const SavedModelBundle& saved_model, + mlir::MLIRContext* context); + // Serialize a MLIR module to a string. std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index f7cf5377bb8..8f3cab0e619 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -130,6 +130,27 @@ mlir::OwningModuleRef SavedModelToMlirImport( return module_or.ConsumeValueOrDie(); } +mlir::OwningModuleRef SavedModelV1ToMlirImport( + absl::string_view saved_model_dir, + const std::unordered_set& tags, mlir::MLIRContext* context) { + tensorflow::SavedModelBundle bundle; + auto load_status = tensorflow::LoadSavedModel( + /* session_options = */ {}, /* run_options = */ {}, + std::string(saved_model_dir), tags, &bundle); + if (!load_status.ok()) { + LOG(ERROR) << "Failed to load saved model v1 '" << saved_model_dir + << "': " << load_status; + return nullptr; + } + + auto module_or = ConvertSavedModelV1ToMlir(bundle, context); + if (!module_or.status().ok()) { + LOG(ERROR) << "SavedModel V1 import failed: " << module_or.status(); + return nullptr; + } + return module_or.ConsumeValueOrDie(); +} + mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, absl::string_view input_arrays, absl::string_view input_dtypes, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index ea5dfffe66e..46e6376207c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -54,6 +54,14 @@ mlir::OwningModuleRef SavedModelToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context); + +// Converts a TensorFlow V1 SavedModel stored in the directory with the given +// `saved_model_dir` into a MLIR module. Creates MLIR entities into the +// given MLIR `context`. +mlir::OwningModuleRef SavedModelV1ToMlirImport( + absl::string_view saved_model_dir, + const std::unordered_set& tags, mlir::MLIRContext* context); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 7e71a1770c7..f5fc56556ec 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -54,6 +54,12 @@ static llvm::cl::opt import_saved_model( llvm::cl::desc("Import a saved model to its MLIR representation"), llvm::cl::value_desc("dir")); +// NOLINTNEXTLINE +static llvm::cl::opt import_saved_model_v1( + "savedmodel-v1-to-mlir", + llvm::cl::desc("Import a saved model V1 to its MLIR representation"), + llvm::cl::value_desc("dir")); + // NOLINTNEXTLINE static llvm::cl::opt saved_model_tags( "tf-savedmodel-tags", @@ -77,10 +83,11 @@ int main(int argc, char** argv) { llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n"); - if (!import_saved_model && !requested_translation) { + if (!import_saved_model && !import_saved_model_v1 && !requested_translation) { llvm::errs() << "error: need to specify one translation to perform\n"; return 1; - } else if (import_saved_model && requested_translation) { + } else if (import_saved_model && import_saved_model_v1 && + requested_translation) { llvm::errs() << "error: cannot specify more than one translation to perform\n"; return 1; @@ -105,6 +112,16 @@ int main(int argc, char** argv) { &context); if (!module) return 1; + module->print(output->os()); + } else if (import_saved_model_v1) { + std::unordered_set tags = + absl::StrSplit(saved_model_tags, ','); + mlir::MLIRContext context; + + auto module = + tensorflow::SavedModelV1ToMlirImport(input_filename, tags, &context); + if (!module) return 1; + module->print(output->os()); } else { auto input = mlir::openInputFile(input_filename, &error_message);