Make function names unchanged during Graphdef to MLIR translation.

Previously, a suffix (usually "0") is added to each function during Graphdef
to MLIR import. With the change, a round-trip of Graph->MLIR->Graph or
MLIR->Graph->MLIR shouldn't change any function names.

Other notable changes:

- Export functions now clear the function library passed in to avoid function
 name conflicts.

-  The V1 bridge doesn't like an optimization pipeline to change library
 function in-place, i.e. create different functions that have the same
 names as the functions in the original library. Therefore,
 the CL adds a "rename_private_functions" pass in the pipeline.

- translation a library function to MLIR function can be done independently from converting the Graph, so this change also lifts the translation to the top level (`GraphDefImporter::Convert`) to simplify the implimentation.

PiperOrigin-RevId: 329071572
Change-Id: Ib33bd492da53c9764f552ef42842efbdcd1a54d7
This commit is contained in:
Jing Pu 2020-08-29 00:39:31 -07:00 committed by TensorFlower Gardener
parent d98960155a
commit a32c74ae8f
17 changed files with 329 additions and 125 deletions

View File

@ -768,6 +768,7 @@ cc_library(
"transforms/promote_resources_to_args.cc",
"transforms/readonly_references_to_resources.cc",
"transforms/region_control_flow_to_functional.cc",
"transforms/rename_private_functions.cc",
"transforms/replicate_invariant_op_hoisting.cc",
"transforms/replicate_to_island.cc",
"transforms/resource_device_inference.cc",
@ -1110,13 +1111,16 @@ cc_library(
":convert_graphdef",
":error_util",
":mlir_roundtrip_flags",
":tensorflow_passes",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)

View File

@ -25,16 +25,17 @@ func @foo() {
// CHECK: func @main()
// CHECK: computation = @[[FUNCTIONALIZE_FUNC:[A-Za-z0-9_]*]]
// We expect the _tf.Add in the else func and the _tf.Mul in the then func
// CHECK: func @[[ELSE_FUNC:[A-Za-z0-9_]*]]
// CHECK: "tf.Add"
// CHECK: func @[[THEN_FUNC:[A-Za-z0-9_]*]]
// CHECK: "tf.Mul"
// In the newly cloned function, check that we have a _tf.If operation and capture the then and else branch.
// CHECK: func @[[FUNCTIONALIZE_FUNC]]
// CHECK: "tf.If"
// CHECK-SAME: else_branch = @[[ELSE_FUNC:[A-Za-z0-9_]*]]
// CHECK-SAME: then_branch = @[[THEN_FUNC:[A-Za-z0-9_]*]]
// CHECK-SAME: else_branch = @[[ELSE_FUNC]]
// CHECK-SAME: then_branch = @[[THEN_FUNC]]
// We expect the _tf.Add in the else func and the _tf.Mul in the then func
// CHECK: func @[[ELSE_FUNC]]
// CHECK: "tf.Add"
// CHECK: func @[[THEN_FUNC]]
// CHECK: "tf.Mul"

View File

@ -54,5 +54,5 @@ versions {
# the names are matching between the function definition and the uses / call
# site (a numerical suffix may be appended).
# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, device = "", f = @foo0}
# CHECK: func @foo0
# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, device = "", f = @foo}
# CHECK: func @foo

View File

@ -68,4 +68,4 @@ library {
}
# CHECK: func @main
# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @test_func_name0}
# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @test_func_name}

View File

@ -3,7 +3,7 @@
# Verify that the _input_shapes attribute of the FunctionDef is respected.
# This also checks that the output type is correctly inferred based on
# that.
#CHECK: func @identity_function0(%arg0: tensor<i32>) -> tensor<i32>
#CHECK: func @identity_function(%arg0: tensor<i32>) -> tensor<i32>
node {
name: "Placeholder"

View File

@ -121,8 +121,8 @@ versions {
# Verify that functions from the library are properly imported.
# CHECK-LABEL: func @main() {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo110}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo111}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo1}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo11}
# CHECK-LABEL: func @foo110() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo111() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo1() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo11() attributes {sym_visibility = "private"}

View File

@ -88,10 +88,10 @@ library {
# CHECK: tf_executor.graph
# CHECK: "tf.VarHandleOp"()
# CHECK: "tf.LegacyCall"
# CHECK-SAME: {_disable_call_shape_inference = true, device = "", f = @test_func_name0}
# CHECK-SAME: {_disable_call_shape_inference = true, device = "", f = @test_func_name}
# CHECK: tf_executor.fetch
# CHECK: return
# CHECK: func @test_func_name0
# CHECK: func @test_func_name
# CHECK-SAME: tf._resource_arg_unique_id = 0
# CHECK-SAME: tf._resource_arg_unique_id = 0
# CHECK: tf_executor.graph

View File

@ -4,7 +4,7 @@
# links the function and its gradient. In MLIR a TF ops gradient function is
# added to its list of function attributes.
# CHECK: func @foo0(
# CHECK: func @foo(
# CHECK: tf.gradient = @foo_grad
node {

View File

@ -54,10 +54,11 @@ versions {
# Verify that functions from the library are properly imported.
# CHECK-LABEL: func @main() {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, device = "", f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, device = "", f = @foo}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar}
# CHECK-LABEL: func @foo0() attributes {sym_visibility = "private"}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0}
# CHECK-LABEL: func @bar() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo() attributes {sym_visibility = "private"}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar}
# CHECK-LABEL: func @bar0() attributes {sym_visibility = "private"}

View File

@ -0,0 +1,53 @@
// RUN: tf-opt %s -allow-unregistered-dialect -split-input-file -tf-rename-private-functions | FileCheck %s
// CHECK-LABEL: @simple
func @simple() {
// CHECK: "my.call"() {func = @[[NEW_FUNC_NAME:.+]]}
"my.call"() {func = @my_func} : () -> ()
return
}
// CHECK-NOT: func @my_func()
// CHECK: func @[[NEW_FUNC_NAME]]()
func @my_func() -> () attributes {sym_visibility = "private"}
// -----
// A stress test case to test uniquification logic
// CHECK-LABEL: @test_uniquification
func @test_uniquification() {
// CHECK: "my.call"() {func = @[[NEW_FUNC_NAME_0:.+]]}
"my.call"() {func = @my_func} : () -> ()
// CHECK: "my.call"() {func = @[[NEW_FUNC_NAME_1:.+]]}
"my.call"() {func = @my_func0} : () -> ()
return
}
// CHECK-NOT: func @my_func()
// CHECK-NOT: func @my_func0()
// CHECK: func @[[NEW_FUNC_NAME_0]]()
func @my_func() -> () attributes {sym_visibility = "private"}
// CHECK: func @[[NEW_FUNC_NAME_1]]()
func @my_func0() -> () attributes {sym_visibility = "private"}
// -----
// Test for SymbolRefArrayAttr
// CHECK-LABEL: @test_case_op
func @test_case_op(%arg0: tensor<i32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "tf.Case"(%arg0, %arg1) {branches = [@branch_one, @branch_two], is_stateless = false} : (tensor<i32>, tensor<2xf32>) -> tensor<2xf32>
// CHECK: "tf.Case"(%arg0, %arg1) {branches = [@[[NEW_FUNC_NAME_1:.+]], @[[NEW_FUNC_NAME_2:.+]]]
return %0 : tensor<2xf32>
}
// CHECK-NOT: func @branch_one()
// CHECK-NOT: func @branch_two()
// CHECK: func @[[NEW_FUNC_NAME_1]]
func @branch_one(tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"}
// CHECK: func @[[NEW_FUNC_NAME_2]]
func @branch_two(tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"}

View File

@ -122,6 +122,10 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
}
void CreateTPUBridgePipelineV1(OpPassManager &pm) {
// Function library in the V1 bridge is likely to contain many unused
// functions, so we remove them early to speed up the rest of the pipeline.
pm.addPass(createSymbolDCEPass());
pm.addPass(TF::CreateTFShapeInferencePass());
// For V1 compatibility, we process a module where the graph does not have
// feeds and fetched. We extract first the TPU computation in a submodule,
@ -132,6 +136,19 @@ void CreateTPUBridgePipelineV1(OpPassManager &pm) {
pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandOutliningPass());
OpPassManager &nested_module = pm.nest<ModuleOp>();
CreateTPUBridgePipeline(nested_module);
// The TF runtime using the V1 bridge doesn't like an optimization pipeline to
// change library functions in-place, i.e. create different functions that
// have the same names as the functions in the original function library. Some
// of this constraint come from the fact that Session can extend its function
// library with the output function library of the bridge and equality checks
// of FunctionDef's are based on exact contents which is not guaranteed by the
// TF importer/exporter nor by the V1 bridge.
//
// Therefore, we rename all these function to new names to avoid any failures
// in Session::Extend.
pm.addPass(TF::CreateRenamePrivateFunctionPass());
pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
}

View File

@ -173,6 +173,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
// Creates function pass to replace InitializeTableFromTextFileV2Ops with
// LookupTableImportV2Op ops.
std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass();
// Creates a module pass that renames all the private functions to new
// names that don't exist in the original module.
std::unique_ptr<OperationPass<ModuleOp>> CreateRenamePrivateFunctionPass();
} // namespace TF
namespace tf_executor {

View File

@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <utility>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
namespace mlir {
namespace TF {
namespace {
// A helper class that generates name strings that are both uniques among
// a pre-defined set of existing strings and among the new strings it generates.
class NameUniquifier : public tensorflow::OpOrArgNameMapper {
public:
explicit NameUniquifier(const llvm::StringSet<> &existing_names)
: existing_names_(existing_names) {}
private:
bool IsUnique(llvm::StringRef name) override {
return !existing_names_.contains(name);
}
std::string GetName(tensorflow::OpOrVal op_or_val) override {
llvm_unreachable("This method shouldn't be used.");
return "";
}
const llvm::StringSet<> &existing_names_;
};
// Returns an updated SymbolRefAttr according to `symbol_renaming_map`.
// If the symbol name is not in the map, then the function returns the `old`
// SymbolRefAttr.
SymbolRefAttr GetUpdatedSymbolRefAttr(
SymbolRefAttr old, const llvm::StringMap<StringRef> &symbol_renaming_map) {
auto it = symbol_renaming_map.find(old.getRootReference());
if (it == symbol_renaming_map.end()) {
return old;
}
StringRef new_symbol_name = it->second;
return SymbolRefAttr::get(new_symbol_name, old.getNestedReferences(),
old.getContext());
}
// A pass that renames all the private functions to new names that don't exist
// in the original module.
struct RenamePrivateFunctionPass
: public PassWrapper<RenamePrivateFunctionPass, OperationPass<ModuleOp>> {
void runOnOperation() override;
};
void RenamePrivateFunctionPass::runOnOperation() {
ModuleOp module = getOperation();
// Get all old function names
llvm::StringSet<> old_private_func_names;
for (auto func : module.getOps<FuncOp>()) {
old_private_func_names.insert(func.getName());
}
// Update private function names
NameUniquifier name_uniquifier(old_private_func_names);
llvm::StringMap<StringRef> func_name_map;
for (auto func : module.getOps<FuncOp>()) {
if (func.isPrivate()) {
StringRef old_name = func.getName();
StringRef new_name = name_uniquifier.GetUniqueName(old_name);
func.setName(new_name);
func_name_map.insert(std::make_pair(old_name, new_name));
}
}
// Update any SymbolRefAttr
module.walk([&func_name_map](Operation *op) {
for (NamedAttribute p : op->getAttrs()) {
Identifier id = p.first;
Attribute attr = p.second;
if (auto symbol_ref = attr.dyn_cast<SymbolRefAttr>()) {
SymbolRefAttr new_symbol_ref =
GetUpdatedSymbolRefAttr(symbol_ref, func_name_map);
if (new_symbol_ref != symbol_ref) {
op->setAttr(id, new_symbol_ref);
}
} else if (auto array_attr = attr.dyn_cast<ArrayAttr>()) {
// Update any SymbolRefAttr in the ArrayAttr
SmallVector<Attribute, 4> new_array;
new_array.reserve(array_attr.size());
for (Attribute attr : array_attr.getValue()) {
if (auto symbol_ref = attr.dyn_cast<SymbolRefAttr>()) {
SymbolRefAttr new_symbol_ref =
GetUpdatedSymbolRefAttr(symbol_ref, func_name_map);
new_array.push_back(new_symbol_ref);
} else {
new_array.push_back(attr);
}
}
auto new_array_attr = ArrayAttr::get(new_array, op->getContext());
if (new_array_attr != array_attr) {
op->setAttr(id, new_array_attr);
}
}
}
});
}
PassRegistration<RenamePrivateFunctionPass> tpu_pass(
"tf-rename-private-functions",
"Renames all the private functions to new names");
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateRenamePrivateFunctionPass() {
return std::make_unique<RenamePrivateFunctionPass>();
}
} // namespace TF
} // namespace mlir

View File

@ -717,6 +717,8 @@ Status Exporter::Convert(mlir::ModuleOp module,
TF_ASSIGN_OR_RETURN(
*graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib,
control_ret_nodes));
flib_def->Clear();
for (auto& func_def : flib.function()) {
TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
}

View File

@ -38,8 +38,9 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(
// Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
// The "main" function of the module is stored in the graph and the rest of
// functions are stored in the library. Control ret nodes are stored separately
// in `control_ret_nodes`.
// functions are stored in the library. Note that existing functions in the
// library will be deleted. Control ret nodes are stored separately in
// `control_ret_nodes`.
stream_executor::port::Status ConvertMlirToGraph(
mlir::ModuleOp module, const GraphExportConfig& configs,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
@ -47,7 +48,8 @@ stream_executor::port::Status ConvertMlirToGraph(
// Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
// The "main" function of the module is stored in the graph and the rest of
// functions are stored in the library.
// functions are stored in the library. Note that existing functions in the
// library will be deleted.
stream_executor::port::Status ConvertMlirToGraph(
mlir::ModuleOp module, const GraphExportConfig& configs,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def);

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
@ -149,34 +150,6 @@ void LoadImporterDialects(mlir::MLIRContext& context) {
registry.loadAll(&context);
}
// This class is used to generate new MLIR function name strings that are both
// unique in the TF function library `flib_` and unique among the name strings
// generated by the class object during its lifetime.
//
// In theory, this class is not necessary because we should simply take
// the TF function name and use it as MLIR function name. However, for some
// unknown reasons (callout for investigation in b/142268695), keeping the
// function names unchanged in an MLIR roundtrip causes test failures.
// TODO(b/142268695) Re-evaluate whether we need this class v.s. directly using
// and TF function name as MLIR function name after b/142268695 is root caused.
class NameUniquifier : public OpOrArgNameMapper {
public:
explicit NameUniquifier(const FunctionLibraryDefinition& flib)
: flib_(flib) {}
private:
bool IsUnique(llvm::StringRef name) override {
return !flib_.Contains(std::string(name));
}
std::string GetName(OpOrVal op_or_val) override {
DCHECK(false) << "Unimplemented";
return "";
}
const FunctionLibraryDefinition& flib_;
};
Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
bool restrict_functionalization_to_tpu_nodes) {
// If `restrict_functionalization_to_tpu_nodes` is true let filter function
@ -202,18 +175,14 @@ class ImporterBase {
explicit ImporterBase(
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
const GraphImportConfig& specs, mlir::ModuleOp module,
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
NameUniquifier* function_name_uniquifier,
llvm::StringRef function_name_for_debug_info = "")
: builder_(module.getContext()),
module_(module),
context_(module.getContext()),
tf_name_to_mlir_name_(tf_name_to_mlir_name),
graph_flib_(flib),
specs_(specs),
debug_info_(debug_info),
function_name_for_debug_info_(function_name_for_debug_info),
function_name_uniquifier_(function_name_uniquifier),
error_handler_(module.getContext()) {}
// Returns the inferred function signature of the given function body. Input
@ -429,7 +398,6 @@ class ImporterBase {
mlir::OpBuilder builder_;
mlir::ModuleOp module_;
mlir::MLIRContext* context_;
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
const FunctionLibraryDefinition& graph_flib_;
const GraphImportConfig& specs_;
const GraphDebugInfo& debug_info_;
@ -439,7 +407,6 @@ class ImporterBase {
// The shape_refinner_ will be nullptr if shape inference on import is
// not enabled.
std::unique_ptr<ShapeRefiner> shape_refiner_ = nullptr;
NameUniquifier* function_name_uniquifier_;
mlir::StatusScopedDiagnosticHandler error_handler_;
protected:
@ -1129,10 +1096,7 @@ Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,
StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName(
const std::string& func_name) {
TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
auto func = module_.lookupSymbol<mlir::FuncOp>(mlir_func_name);
return builder_.getSymbolRefAttr(func);
return builder_.getSymbolRefAttr(func_name);
}
StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
@ -1225,16 +1189,6 @@ void ImporterBase::GetArgsAndRetsFromFunctionBody(
}
Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
// If the library function has been converted already, nothing needs to be
// done.
if (tf_name_to_mlir_name_->find(std::string(func_name)) !=
tf_name_to_mlir_name_->end())
return Status::OK();
std::string mlir_func_name(
function_name_uniquifier_->GetUniqueName(func_name));
(*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name;
const auto& func_lib = graph_flib_;
const auto* func_def = func_lib.Find(std::string(func_name));
if (func_def == nullptr) {
@ -1272,10 +1226,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
// list of this function.
auto grad_func_name = func_lib.FindGradient(std::string(func_name));
if (!grad_func_name.empty()) {
TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
auto grad_func = module_.lookupSymbol<mlir::FuncOp>(mlir_grad_func_name);
auto gradient_attr = builder_.getSymbolRefAttr(grad_func);
auto gradient_attr = builder_.getSymbolRefAttr(grad_func_name);
auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
}
@ -1305,7 +1256,6 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
}
ImporterBase child_importer(graph_flib_, debug_info_, specs, module_,
tf_name_to_mlir_name_, function_name_uniquifier_,
func_name);
TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph));
@ -1319,7 +1269,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
&control_ret_nodes);
TF_RETURN_IF_ERROR(child_importer.Convert(
mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
llvm::makeArrayRef(attributes.begin(), attributes.end())));
return Status::OK();
}
@ -1803,14 +1753,9 @@ Status ImporterBase::ConvertNode(const Node& node) {
// If it is a custom OP, its definition should be found in the library. We
// create the MLIR function and insert it to the module if it doesn't exist.
std::string node_type_name = node.type_string();
const std::string& node_type_name = node.type_string();
const auto* func_def = graph_flib_.Find(node_type_name);
bool convert_to_legacy_call = false;
if (func_def) {
TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
convert_to_legacy_call = true;
}
const bool convert_to_legacy_call = func_def != nullptr;
auto get_full_op_name = [&](const std::string& op_name) {
const char* kTfPrefix = "tf.";
@ -2097,21 +2042,21 @@ StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
// in the module.
class GraphDefImporter : public ImporterBase {
public:
// Main entry point: converts the given graph to an MLIR Module.
static StatusOr<mlir::OwningModuleRef> Convert(
// Main entry point: converts the given `graph` and library functions to an
// MLIR Module. `graph` is translated to a public function with `func_name`,
// while other functions in `flib_def` become private functions in the module.
static StatusOr<mlir::OwningModuleRef> ConvertGraphAndLibraryFunctions(
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info,
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
llvm::StringRef func_name);
private:
explicit GraphDefImporter(
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
const GraphImportConfig& specs, mlir::ModuleOp module,
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
NameUniquifier* function_name_uniquifier)
: ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
function_name_uniquifier) {}
explicit GraphDefImporter(const FunctionLibraryDefinition& flib,
const GraphDebugInfo& debug_info,
const GraphImportConfig& specs,
mlir::ModuleOp module)
: ImporterBase(flib, debug_info, specs, module) {}
// Returns the function signature of the main function of converted MLIR
// module, the input nodes and output nodes. The type and shape information
@ -2140,18 +2085,16 @@ class GraphDefImporter : public ImporterBase {
absl::InlinedVector<Node*, 4>* control_ret_nodes);
};
StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
StatusOr<mlir::OwningModuleRef>
GraphDefImporter::ConvertGraphAndLibraryFunctions(
mlir::MLIRContext* context, const Graph& graph,
const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
const GraphImportConfig& specs, llvm::StringRef func_name) {
LoadImporterDialects(*context);
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
NameUniquifier function_name_uniquifier(flib_def);
GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
&tf_name_to_mlir_name, &function_name_uniquifier);
GraphDefImporter importer(flib_def, debug_info, specs, module.get());
TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));
@ -2238,6 +2181,12 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));
std::vector<std::string> lib_func_names = flib_def.ListFunctionNames();
std::sort(lib_func_names.begin(), lib_func_names.end());
for (const std::string& fn_name : lib_func_names) {
TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
}
// Mark main function public, others private.
for (auto function : module.get().getOps<mlir::FuncOp>()) {
auto visibility = function.getName() == func_name
@ -2459,13 +2408,11 @@ class SavedModelObjectGraphImporter : public ImporterBase {
mlir::MLIRContext* context, bool add_default_attributes);
private:
explicit SavedModelObjectGraphImporter(
const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
const GraphImportConfig& specs, mlir::ModuleOp module,
std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
NameUniquifier* function_name_uniquifier)
: ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
function_name_uniquifier) {}
explicit SavedModelObjectGraphImporter(const FunctionLibraryDefinition& flib,
const GraphDebugInfo& debug_info,
const GraphImportConfig& specs,
mlir::ModuleOp module)
: ImporterBase(flib, debug_info, specs, module) {}
};
// Determines the names used to reference objects in the SavedObjectGraph.
@ -3008,7 +2955,6 @@ void SortSavedModelModule(mlir::ModuleOp module) {
Status CreateSavedModelIR(
const ObjectNames& object_names, mlir::ModuleOp module,
const SavedObjectGraph& object_graph,
const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
SavedModelV2Bundle* saved_model) {
mlir::OpBuilder builder(module.getBodyRegion());
mlir::SymbolTable symbol_table(module);
@ -3043,8 +2989,8 @@ Status CreateSavedModelIR(
"While importing SavedModel function '" +
object_names.GetExportedNames(node_id)[0].str() + "': ";
const SavedFunction& function = object.function();
auto orig_func = symbol_table.lookup<mlir::FuncOp>(
tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
auto orig_func =
symbol_table.lookup<mlir::FuncOp>(function.concrete_functions(0));
mlir::FuncOp func = orig_func;
// If there are potentially references to this func from within the
// module, create a wrapper around it and decorate the wrapper with the
@ -3210,7 +3156,6 @@ StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
specs.prune_unused_nodes = true;
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
const auto& graphdef = saved_model->meta_graph_def().graph_def();
PopulateTfVersions(module.get(), graphdef.versions());
@ -3228,10 +3173,8 @@ StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));
NameUniquifier function_name_uniquifier(graph.flib_def());
SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs,
module.get(), &tf_name_to_mlir_name,
&function_name_uniquifier);
module.get());
TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));
@ -3265,8 +3208,7 @@ StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
// Construct the SavedModel IR.
TF_RETURN_IF_ERROR(CreateSavedModelIR(object_names, module.get(),
object_graph, tf_name_to_mlir_name,
saved_model));
object_graph, saved_model));
assert(mlir::succeeded(mlir::verify(module.get())));
return module;
@ -3510,8 +3452,9 @@ StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefImporter::ConvertGraph(
specs.control_outputs = control_outputs;
// Convert sub-graph to MLIR module.true
return GraphDefImporter::Convert(module_->getContext(), graph(), debug_info(),
graph().flib_def(), specs, name);
return GraphDefImporter::ConvertGraphAndLibraryFunctions(
module_->getContext(), graph(), debug_info(), graph().flib_def(), specs,
name);
}
Status SavedModelSignatureDefImporter::ConvertSignature(
@ -3636,7 +3579,8 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
const_cast<FunctionLibraryDefinition*>(&flib_def),
specs.restrict_functionalization_to_tpu_nodes));
}
return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
return GraphDefImporter::ConvertGraphAndLibraryFunctions(
context, graph, debug_info, flib_def, specs,
/*func_name=*/"main");
}
@ -3651,11 +3595,16 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(),
&flib_def, &fbody));
// Create a copy of function library that doesn't contain the function `name`
// because it duplicates the `fbody` graph.
FunctionLibraryDefinition flib_def_copy = flib_def;
TF_RETURN_IF_ERROR(flib_def_copy.RemoveFunction(name.str()));
tensorflow::GraphDebugInfo dummy_debug_info;
tensorflow::GraphImportConfig specs;
specs.graph_as_function = true;
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
flib_def, specs, name);
return GraphDefImporter::ConvertGraphAndLibraryFunctions(
context, *fbody->graph, dummy_debug_info, flib_def_copy, specs, name);
}
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(

View File

@ -18,6 +18,10 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
@ -52,12 +56,36 @@ static Status Export(mlir::OwningModuleRef module,
const GraphOptimizationPassOptions& options,
std::unique_ptr<Graph>* graph) {
GraphExportConfig confs;
return ConvertMlirToGraph(*module, confs, graph, options.flib_def);
FunctionLibraryDefinition exported_function_library(OpRegistry::Global(), {});
TF_RETURN_IF_ERROR(
ConvertMlirToGraph(*module, confs, graph, &exported_function_library));
return options.flib_def->AddLibrary(exported_function_library);
}
static Status Roundtrip(const GraphOptimizationPassOptions& options,
std::unique_ptr<Graph>* graph, MLIRContext* context) {
TF_ASSIGN_OR_RETURN(auto module, Import(options, **graph, context));
{
// The TF runtime doesn't like an optimization pipeline
// to change library functions in-place, i.e. create different functions
// that have the same names as the functions in the original function
// library. Some of this constraint come from the fact that Session can
// extend its function library with the output function library of the
// bridge and equality checks of FunctionDef's are based on exact contents
// which is not guaranteed by the TF importer/exporter.
//
// Therefore, we rename all these function to new names to avoid any
// failures in Session::Extend.
mlir::PassManager pm(context);
pm.addPass(mlir::createSymbolDCEPass());
pm.addPass(mlir::TF::CreateRenamePrivateFunctionPass());
mlir::StatusScopedDiagnosticHandler status_handler(context);
if (mlir::failed(pm.run(module.get()))) {
return status_handler.ConsumeStatus();
}
}
return Export(std::move(module), options, graph);
}