Keep function names unchanged during GraphDef-to-MLIR translation.
Previously, a suffix (usually "0") was appended to each function name during GraphDef-to-MLIR import. With this change, a round-trip of Graph->MLIR->Graph or MLIR->Graph->MLIR should not change any function names.

Other notable changes:

- Export functions now clear the function library passed in, to avoid function name conflicts.
- The V1 bridge doesn't tolerate an optimization pipeline that changes library functions in-place, i.e. creates different functions that have the same names as the functions in the original library. Therefore, this CL adds a "rename_private_functions" pass to the pipeline.
- Translating a library function to an MLIR function can be done independently of converting the Graph, so this change also lifts that translation to the top level (`GraphDefImporter::Convert`) to simplify the implementation.

PiperOrigin-RevId: 329071572
Change-Id: Ib33bd492da53c9764f552ef42842efbdcd1a54d7
parent d98960155a · commit a32c74ae8f
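To make the name stability concrete, here is a round-trip sketch (not part of the CL; `graph`, `debug_info`, `flib_def`, `import_specs`, and `context` are assumed to be set up, and error handling is elided):

    // Round-trip sketch: after this change, a library function "foo" imports
    // as MLIR @foo (previously @foo0) and exports back as FunctionDef "foo".
    TF_ASSIGN_OR_RETURN(
        mlir::OwningModuleRef module,
        ConvertGraphToMlir(graph, debug_info, flib_def, import_specs, &context));
    GraphExportConfig export_confs;
    std::unique_ptr<Graph> exported_graph;
    // Note: the exporter now clears the passed-in library before repopulating it.
    TF_RETURN_IF_ERROR(
        ConvertMlirToGraph(*module, export_confs, &exported_graph, &flib_def));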
@ -768,6 +768,7 @@ cc_library(
        "transforms/promote_resources_to_args.cc",
        "transforms/readonly_references_to_resources.cc",
        "transforms/region_control_flow_to_functional.cc",
        "transforms/rename_private_functions.cc",
        "transforms/replicate_invariant_op_hoisting.cc",
        "transforms/replicate_to_island.cc",
        "transforms/resource_device_inference.cc",
@ -1110,13 +1111,16 @@ cc_library(
        ":convert_graphdef",
        ":error_util",
        ":mlir_roundtrip_flags",
        ":tensorflow_passes",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:Transforms",
    ],
    alwayslink = 1,
)
@ -25,16 +25,17 @@ func @foo() {
// CHECK: func @main()
// CHECK: computation = @[[FUNCTIONALIZE_FUNC:[A-Za-z0-9_]*]]

// We expect the _tf.Add in the else func and the _tf.Mul in the then func

// CHECK: func @[[ELSE_FUNC:[A-Za-z0-9_]*]]
// CHECK: "tf.Add"
// CHECK: func @[[THEN_FUNC:[A-Za-z0-9_]*]]
// CHECK: "tf.Mul"

// In the newly cloned function, check that we have a _tf.If operation and capture the then and else branch.
// CHECK: func @[[FUNCTIONALIZE_FUNC]]
// CHECK: "tf.If"
// CHECK-SAME: else_branch = @[[ELSE_FUNC:[A-Za-z0-9_]*]]
// CHECK-SAME: then_branch = @[[THEN_FUNC:[A-Za-z0-9_]*]]
// CHECK-SAME: else_branch = @[[ELSE_FUNC]]
// CHECK-SAME: then_branch = @[[THEN_FUNC]]

// We expect the _tf.Add in the else func and the _tf.Mul in the then func

// CHECK: func @[[ELSE_FUNC]]
// CHECK: "tf.Add"
// CHECK: func @[[THEN_FUNC]]
// CHECK: "tf.Mul"
@ -54,5 +54,5 @@ versions {
# the names are matching between the function definition and the uses / call
# site (a numerical suffix may be appended).

# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, device = "", f = @foo0}
# CHECK: func @foo0
# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, device = "", f = @foo}
# CHECK: func @foo
@ -68,4 +68,4 @@ library {
}

# CHECK: func @main
# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @test_func_name0}
# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @test_func_name}
@ -3,7 +3,7 @@
# Verify that the _input_shapes attribute of the FunctionDef is respected.
# This also checks that the output type is correctly inferred based on
# that.
#CHECK: func @identity_function0(%arg0: tensor<i32>) -> tensor<i32>
#CHECK: func @identity_function(%arg0: tensor<i32>) -> tensor<i32>

node {
  name: "Placeholder"
@ -121,8 +121,8 @@ versions {
# Verify that functions from the library are properly imported.

# CHECK-LABEL: func @main() {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo110}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo111}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo1}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo11}

# CHECK-LABEL: func @foo110() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo111() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo1() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo11() attributes {sym_visibility = "private"}
@ -88,10 +88,10 @@ library {
# CHECK: tf_executor.graph
# CHECK: "tf.VarHandleOp"()
# CHECK: "tf.LegacyCall"
# CHECK-SAME: {_disable_call_shape_inference = true, device = "", f = @test_func_name0}
# CHECK-SAME: {_disable_call_shape_inference = true, device = "", f = @test_func_name}
# CHECK: tf_executor.fetch
# CHECK: return
# CHECK: func @test_func_name0
# CHECK: func @test_func_name
# CHECK-SAME: tf._resource_arg_unique_id = 0
# CHECK-SAME: tf._resource_arg_unique_id = 0
# CHECK: tf_executor.graph
@ -4,7 +4,7 @@
# links the function and its gradient. In MLIR a TF op's gradient function is
# added to its list of function attributes.

# CHECK: func @foo0(
# CHECK: func @foo(
# CHECK: tf.gradient = @foo_grad

node {
@ -54,10 +54,11 @@ versions {
# Verify that functions from the library are properly imported.

# CHECK-LABEL: func @main() {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, device = "", f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, device = "", f = @foo}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar}

# CHECK-LABEL: func @foo0() attributes {sym_visibility = "private"}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0}
# CHECK-LABEL: func @bar() attributes {sym_visibility = "private"}

# CHECK-LABEL: func @foo() attributes {sym_visibility = "private"}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar}

# CHECK-LABEL: func @bar0() attributes {sym_visibility = "private"}
@ -0,0 +1,53 @@
// RUN: tf-opt %s -allow-unregistered-dialect -split-input-file -tf-rename-private-functions | FileCheck %s

// CHECK-LABEL: @simple
func @simple() {
  // CHECK: "my.call"() {func = @[[NEW_FUNC_NAME:.+]]}
  "my.call"() {func = @my_func} : () -> ()
  return
}

// CHECK-NOT: func @my_func()
// CHECK: func @[[NEW_FUNC_NAME]]()
func @my_func() -> () attributes {sym_visibility = "private"}

// -----

// A stress test case to test uniquification logic

// CHECK-LABEL: @test_uniquification
func @test_uniquification() {
  // CHECK: "my.call"() {func = @[[NEW_FUNC_NAME_0:.+]]}
  "my.call"() {func = @my_func} : () -> ()
  // CHECK: "my.call"() {func = @[[NEW_FUNC_NAME_1:.+]]}
  "my.call"() {func = @my_func0} : () -> ()
  return
}

// CHECK-NOT: func @my_func()
// CHECK-NOT: func @my_func0()

// CHECK: func @[[NEW_FUNC_NAME_0]]()
func @my_func() -> () attributes {sym_visibility = "private"}
// CHECK: func @[[NEW_FUNC_NAME_1]]()
func @my_func0() -> () attributes {sym_visibility = "private"}

// -----

// Test for SymbolRefArrayAttr

// CHECK-LABEL: @test_case_op
func @test_case_op(%arg0: tensor<i32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch_one, @branch_two], is_stateless = false} : (tensor<i32>, tensor<2xf32>) -> tensor<2xf32>
  // CHECK: "tf.Case"(%arg0, %arg1) {branches = [@[[NEW_FUNC_NAME_1:.+]], @[[NEW_FUNC_NAME_2:.+]]]
  return %0 : tensor<2xf32>
}
// CHECK-NOT: func @branch_one()
// CHECK-NOT: func @branch_two()

// CHECK: func @[[NEW_FUNC_NAME_1]]
func @branch_one(tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"}
// CHECK: func @[[NEW_FUNC_NAME_2]]
func @branch_two(tensor<2xf32>) -> tensor<2xf32> attributes {sym_visibility = "private"}
@ -122,6 +122,10 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
}

void CreateTPUBridgePipelineV1(OpPassManager &pm) {
  // Function library in the V1 bridge is likely to contain many unused
  // functions, so we remove them early to speed up the rest of the pipeline.
  pm.addPass(createSymbolDCEPass());

  pm.addPass(TF::CreateTFShapeInferencePass());
  // For V1 compatibility, we process a module where the graph does not have
  // feeds and fetches. We extract first the TPU computation in a submodule,
@ -132,6 +136,19 @@ void CreateTPUBridgePipelineV1(OpPassManager &pm) {
  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandOutliningPass());
  OpPassManager &nested_module = pm.nest<ModuleOp>();
  CreateTPUBridgePipeline(nested_module);

  // The TF runtime using the V1 bridge doesn't like an optimization pipeline to
  // change library functions in-place, i.e. create different functions that
  // have the same names as the functions in the original function library. Some
  // of these constraints come from the fact that Session can extend its function
  // library with the output function library of the bridge, and equality checks
  // of FunctionDefs are based on exact contents, which is not guaranteed by the
  // TF importer/exporter nor by the V1 bridge.
  //
  // Therefore, we rename all these functions to new names to avoid any failures
  // in Session::Extend.
  pm.addPass(TF::CreateRenamePrivateFunctionPass());

  pm.addPass(tf_executor::CreateTFExecutorTPUV1IslandInliningPass());
}
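The Session::Extend failure mode this comment describes can be sketched as follows (illustrative only; `original_graph_def` and `bridge_output_graph_def` are hypothetical placeholders):

    // Session::Extend merges the new GraphDef's function library into the
    // session's library, and re-adding a FunctionDef under an existing name
    // with different contents is rejected.
    std::unique_ptr<tensorflow::Session> session(
        tensorflow::NewSession(tensorflow::SessionOptions()));
    TF_RETURN_IF_ERROR(session->Create(original_graph_def));
    // A bridge that rewrote a library function in-place under its old name
    // would make this Extend fail; with freshly renamed private functions
    // it succeeds.
    TF_RETURN_IF_ERROR(session->Extend(bridge_output_graph_def));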
@ -173,6 +173,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
// Creates function pass to replace InitializeTableFromTextFileV2Ops with
// LookupTableImportV2Op ops.
std::unique_ptr<OperationPass<FuncOp>> CreateInitTextFileToImportPass();

// Creates a module pass that renames all the private functions to new
// names that don't exist in the original module.
std::unique_ptr<OperationPass<ModuleOp>> CreateRenamePrivateFunctionPass();
}  // namespace TF

namespace tf_executor {
@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"

namespace mlir {
namespace TF {

namespace {

// A helper class that generates name strings that are both unique among
// a pre-defined set of existing strings and among the new strings it generates.
class NameUniquifier : public tensorflow::OpOrArgNameMapper {
 public:
  explicit NameUniquifier(const llvm::StringSet<> &existing_names)
      : existing_names_(existing_names) {}

 private:
  bool IsUnique(llvm::StringRef name) override {
    return !existing_names_.contains(name);
  }

  std::string GetName(tensorflow::OpOrVal op_or_val) override {
    llvm_unreachable("This method shouldn't be used.");
    return "";
  }

  const llvm::StringSet<> &existing_names_;
};

// Returns an updated SymbolRefAttr according to `symbol_renaming_map`.
// If the symbol name is not in the map, then the function returns the `old`
// SymbolRefAttr.
SymbolRefAttr GetUpdatedSymbolRefAttr(
    SymbolRefAttr old, const llvm::StringMap<StringRef> &symbol_renaming_map) {
  auto it = symbol_renaming_map.find(old.getRootReference());
  if (it == symbol_renaming_map.end()) {
    return old;
  }

  StringRef new_symbol_name = it->second;
  return SymbolRefAttr::get(new_symbol_name, old.getNestedReferences(),
                            old.getContext());
}

// A pass that renames all the private functions to new names that don't exist
// in the original module.
struct RenamePrivateFunctionPass
    : public PassWrapper<RenamePrivateFunctionPass, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

void RenamePrivateFunctionPass::runOnOperation() {
  ModuleOp module = getOperation();

  // Get all old function names
  llvm::StringSet<> old_private_func_names;
  for (auto func : module.getOps<FuncOp>()) {
    old_private_func_names.insert(func.getName());
  }

  // Update private function names
  NameUniquifier name_uniquifier(old_private_func_names);
  llvm::StringMap<StringRef> func_name_map;
  for (auto func : module.getOps<FuncOp>()) {
    if (func.isPrivate()) {
      StringRef old_name = func.getName();
      StringRef new_name = name_uniquifier.GetUniqueName(old_name);
      func.setName(new_name);
      func_name_map.insert(std::make_pair(old_name, new_name));
    }
  }

  // Update any SymbolRefAttr
  module.walk([&func_name_map](Operation *op) {
    for (NamedAttribute p : op->getAttrs()) {
      Identifier id = p.first;
      Attribute attr = p.second;
      if (auto symbol_ref = attr.dyn_cast<SymbolRefAttr>()) {
        SymbolRefAttr new_symbol_ref =
            GetUpdatedSymbolRefAttr(symbol_ref, func_name_map);
        if (new_symbol_ref != symbol_ref) {
          op->setAttr(id, new_symbol_ref);
        }
      } else if (auto array_attr = attr.dyn_cast<ArrayAttr>()) {
        // Update any SymbolRefAttr in the ArrayAttr
        SmallVector<Attribute, 4> new_array;
        new_array.reserve(array_attr.size());
        for (Attribute attr : array_attr.getValue()) {
          if (auto symbol_ref = attr.dyn_cast<SymbolRefAttr>()) {
            SymbolRefAttr new_symbol_ref =
                GetUpdatedSymbolRefAttr(symbol_ref, func_name_map);
            new_array.push_back(new_symbol_ref);
          } else {
            new_array.push_back(attr);
          }
        }
        auto new_array_attr = ArrayAttr::get(new_array, op->getContext());
        if (new_array_attr != array_attr) {
          op->setAttr(id, new_array_attr);
        }
      }
    }
  });
}

PassRegistration<RenamePrivateFunctionPass> tpu_pass(
    "tf-rename-private-functions",
    "Renames all the private functions to new names");

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateRenamePrivateFunctionPass() {
  return std::make_unique<RenamePrivateFunctionPass>();
}

}  // namespace TF
}  // namespace mlir
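For orientation, a minimal way to drive the new pass from C++ (a sketch, not part of the CL; assumes an already-loaded mlir::ModuleOp named `module` and mirrors the Roundtrip change later in this diff):

    // Run the rename pass standalone: afterwards every private function has a
    // fresh name that did not occur in the original module, and SymbolRefAttr
    // and ArrayAttr references to the old names are rewritten to match.
    mlir::PassManager pm(module.getContext());
    pm.addPass(mlir::TF::CreateRenamePrivateFunctionPass());
    if (mlir::failed(pm.run(module))) {
      // Pass failure; diagnostics have been emitted to the context.
    }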
@ -717,6 +717,8 @@ Status Exporter::Convert(mlir::ModuleOp module,
  TF_ASSIGN_OR_RETURN(
      *graph, Exporter::Convert(configs, tf_dialect, entry_func.value(), &flib,
                                control_ret_nodes));

  flib_def->Clear();
  for (auto& func_def : flib.function()) {
    TF_RETURN_IF_ERROR(flib_def->AddFunctionDef(func_def));
  }
@ -38,8 +38,9 @@ StatusOr<std::unique_ptr<GraphDef>> ConvertMlirToGraphdef(

// Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
// The "main" function of the module is stored in the graph and the rest of
// functions are stored in the library. Control ret nodes are stored separately
// in `control_ret_nodes`.
// functions are stored in the library. Note that existing functions in the
// library will be deleted. Control ret nodes are stored separately in
// `control_ret_nodes`.
stream_executor::port::Status ConvertMlirToGraph(
    mlir::ModuleOp module, const GraphExportConfig& configs,
    std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
@ -47,7 +48,8 @@ stream_executor::port::Status ConvertMlirToGraph(

// Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition.
// The "main" function of the module is stored in the graph and the rest of
// functions are stored in the library.
// functions are stored in the library. Note that existing functions in the
// library will be deleted.
stream_executor::port::Status ConvertMlirToGraph(
    mlir::ModuleOp module, const GraphExportConfig& configs,
    std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def);
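Because the exporter now deletes existing functions in the passed-in library, a caller that needs to preserve a pre-existing library can export into a scratch library and merge afterwards, as the graph_optimization_pass change later in this diff does. A sketch (`session_flib_def` is a hypothetical pre-existing FunctionLibraryDefinition*):

    // Export into a scratch library so the caller's library is never cleared,
    // then merge the result; AddLibrary fails on conflicting redefinitions.
    FunctionLibraryDefinition scratch_flib(OpRegistry::Global(), {});
    std::unique_ptr<Graph> exported_graph;
    GraphExportConfig confs;
    TF_RETURN_IF_ERROR(
        ConvertMlirToGraph(module, confs, &exported_graph, &scratch_flib));
    TF_RETURN_IF_ERROR(session_flib_def->AddLibrary(scratch_flib));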
@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
@ -149,34 +150,6 @@ void LoadImporterDialects(mlir::MLIRContext& context) {
  registry.loadAll(&context);
}

// This class is used to generate new MLIR function name strings that are both
// unique in the TF function library `flib_` and unique among the name strings
// generated by the class object during its lifetime.
//
// In theory, this class is not necessary because we should simply take
// the TF function name and use it as MLIR function name. However, for some
// unknown reasons (callout for investigation in b/142268695), keeping the
// function names unchanged in an MLIR roundtrip causes test failures.
// TODO(b/142268695) Re-evaluate whether we need this class vs. directly using
// the TF function name as MLIR function name after b/142268695 is root caused.
class NameUniquifier : public OpOrArgNameMapper {
 public:
  explicit NameUniquifier(const FunctionLibraryDefinition& flib)
      : flib_(flib) {}

 private:
  bool IsUnique(llvm::StringRef name) override {
    return !flib_.Contains(std::string(name));
  }

  std::string GetName(OpOrVal op_or_val) override {
    DCHECK(false) << "Unimplemented";
    return "";
  }

  const FunctionLibraryDefinition& flib_;
};

Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def,
                          bool restrict_functionalization_to_tpu_nodes) {
  // If `restrict_functionalization_to_tpu_nodes` is true let filter function
@ -202,18 +175,14 @@ class ImporterBase {
  explicit ImporterBase(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier,
      llvm::StringRef function_name_for_debug_info = "")
      : builder_(module.getContext()),
        module_(module),
        context_(module.getContext()),
        tf_name_to_mlir_name_(tf_name_to_mlir_name),
        graph_flib_(flib),
        specs_(specs),
        debug_info_(debug_info),
        function_name_for_debug_info_(function_name_for_debug_info),
        function_name_uniquifier_(function_name_uniquifier),
        error_handler_(module.getContext()) {}

  // Returns the inferred function signature of the given function body. Input
@ -429,7 +398,6 @@ class ImporterBase {
  mlir::OpBuilder builder_;
  mlir::ModuleOp module_;
  mlir::MLIRContext* context_;
  std::unordered_map<std::string, std::string>* tf_name_to_mlir_name_;
  const FunctionLibraryDefinition& graph_flib_;
  const GraphImportConfig& specs_;
  const GraphDebugInfo& debug_info_;
@ -439,7 +407,6 @@ class ImporterBase {
  // The shape_refiner_ will be nullptr if shape inference on import is
  // not enabled.
  std::unique_ptr<ShapeRefiner> shape_refiner_ = nullptr;
  NameUniquifier* function_name_uniquifier_;
  mlir::StatusScopedDiagnosticHandler error_handler_;

 protected:
@ -1129,10 +1096,7 @@ Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name,

StatusOr<mlir::FlatSymbolRefAttr> ImporterBase::ConvertFunctionCallName(
    const std::string& func_name) {
  TF_RETURN_IF_ERROR(ConvertLibFunction(func_name));
  auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name];
  auto func = module_.lookupSymbol<mlir::FuncOp>(mlir_func_name);
  return builder_.getSymbolRefAttr(func);
  return builder_.getSymbolRefAttr(func_name);
}

StatusOr<mlir::Attribute> ImporterBase::ConvertAttributeValue(
@ -1225,16 +1189,6 @@ void ImporterBase::GetArgsAndRetsFromFunctionBody(
}

Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
  // If the library function has been converted already, nothing needs to be
  // done.
  if (tf_name_to_mlir_name_->find(std::string(func_name)) !=
      tf_name_to_mlir_name_->end())
    return Status::OK();

  std::string mlir_func_name(
      function_name_uniquifier_->GetUniqueName(func_name));
  (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name;

  const auto& func_lib = graph_flib_;
  const auto* func_def = func_lib.Find(std::string(func_name));
  if (func_def == nullptr) {
@ -1272,10 +1226,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
  // list of this function.
  auto grad_func_name = func_lib.FindGradient(std::string(func_name));
  if (!grad_func_name.empty()) {
    TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name));
    auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name];
    auto grad_func = module_.lookupSymbol<mlir::FuncOp>(mlir_grad_func_name);
    auto gradient_attr = builder_.getSymbolRefAttr(grad_func);
    auto gradient_attr = builder_.getSymbolRefAttr(grad_func_name);
    auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName();
    attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr));
  }
@ -1305,7 +1256,6 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
  }

  ImporterBase child_importer(graph_flib_, debug_info_, specs, module_,
                              tf_name_to_mlir_name_, function_name_uniquifier_,
                              func_name);
  TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph));

@ -1319,7 +1269,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
                                           &control_ret_nodes);

  TF_RETURN_IF_ERROR(child_importer.Convert(
      mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
      func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
      llvm::makeArrayRef(attributes.begin(), attributes.end())));
  return Status::OK();
}
@ -1803,14 +1753,9 @@ Status ImporterBase::ConvertNode(const Node& node) {

  // If it is a custom OP, its definition should be found in the library. We
  // create the MLIR function and insert it to the module if it doesn't exist.
  std::string node_type_name = node.type_string();
  const std::string& node_type_name = node.type_string();
  const auto* func_def = graph_flib_.Find(node_type_name);
  bool convert_to_legacy_call = false;
  if (func_def) {
    TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name));
    node_type_name = (*tf_name_to_mlir_name_)[node_type_name];
    convert_to_legacy_call = true;
  }
  const bool convert_to_legacy_call = func_def != nullptr;

  auto get_full_op_name = [&](const std::string& op_name) {
    const char* kTfPrefix = "tf.";
@ -2097,21 +2042,21 @@ StatusOr<mlir::FunctionType> ImporterBase::InferLibFunctionType(
// in the module.
class GraphDefImporter : public ImporterBase {
 public:
  // Main entry point: converts the given graph to an MLIR Module.
  static StatusOr<mlir::OwningModuleRef> Convert(
  // Main entry point: converts the given `graph` and library functions to an
  // MLIR Module. `graph` is translated to a public function with `func_name`,
  // while other functions in `flib_def` become private functions in the module.
  static StatusOr<mlir::OwningModuleRef> ConvertGraphAndLibraryFunctions(
      mlir::MLIRContext* context, const Graph& graph,
      const GraphDebugInfo& debug_info,
      const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
      llvm::StringRef func_name);

 private:
  explicit GraphDefImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}
  explicit GraphDefImporter(const FunctionLibraryDefinition& flib,
                            const GraphDebugInfo& debug_info,
                            const GraphImportConfig& specs,
                            mlir::ModuleOp module)
      : ImporterBase(flib, debug_info, specs, module) {}

  // Returns the function signature of the main function of converted MLIR
  // module, the input nodes and output nodes. The type and shape information
@ -2140,18 +2085,16 @@ class GraphDefImporter : public ImporterBase {
      absl::InlinedVector<Node*, 4>* control_ret_nodes);
};

StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
StatusOr<mlir::OwningModuleRef>
GraphDefImporter::ConvertGraphAndLibraryFunctions(
    mlir::MLIRContext* context, const Graph& graph,
    const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def,
    const GraphImportConfig& specs, llvm::StringRef func_name) {
  LoadImporterDialects(*context);
  mlir::OwningModuleRef module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  NameUniquifier function_name_uniquifier(flib_def);

  GraphDefImporter importer(flib_def, debug_info, specs, module.get(),
                            &tf_name_to_mlir_name, &function_name_uniquifier);
  GraphDefImporter importer(flib_def, debug_info, specs, module.get());

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));

@ -2238,6 +2181,12 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
  TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
      func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));

  std::vector<std::string> lib_func_names = flib_def.ListFunctionNames();
  std::sort(lib_func_names.begin(), lib_func_names.end());
  for (const std::string& fn_name : lib_func_names) {
    TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name));
  }

  // Mark main function public, others private.
  for (auto function : module.get().getOps<mlir::FuncOp>()) {
    auto visibility = function.getName() == func_name
@ -2459,13 +2408,11 @@ class SavedModelObjectGraphImporter : public ImporterBase {
      mlir::MLIRContext* context, bool add_default_attributes);

 private:
  explicit SavedModelObjectGraphImporter(
      const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info,
      const GraphImportConfig& specs, mlir::ModuleOp module,
      std::unordered_map<std::string, std::string>* tf_name_to_mlir_name,
      NameUniquifier* function_name_uniquifier)
      : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name,
                     function_name_uniquifier) {}
  explicit SavedModelObjectGraphImporter(const FunctionLibraryDefinition& flib,
                                         const GraphDebugInfo& debug_info,
                                         const GraphImportConfig& specs,
                                         mlir::ModuleOp module)
      : ImporterBase(flib, debug_info, specs, module) {}
};

// Determines the names used to reference objects in the SavedObjectGraph.
@ -3008,7 +2955,6 @@ void SortSavedModelModule(mlir::ModuleOp module) {
Status CreateSavedModelIR(
    const ObjectNames& object_names, mlir::ModuleOp module,
    const SavedObjectGraph& object_graph,
    const std::unordered_map<std::string, std::string>& tf_name_to_mlir_name,
    SavedModelV2Bundle* saved_model) {
  mlir::OpBuilder builder(module.getBodyRegion());
  mlir::SymbolTable symbol_table(module);
@ -3043,8 +2989,8 @@ Status CreateSavedModelIR(
          "While importing SavedModel function '" +
          object_names.GetExportedNames(node_id)[0].str() + "': ";
      const SavedFunction& function = object.function();
      auto orig_func = symbol_table.lookup<mlir::FuncOp>(
          tf_name_to_mlir_name.find(function.concrete_functions(0))->second);
      auto orig_func =
          symbol_table.lookup<mlir::FuncOp>(function.concrete_functions(0));
      mlir::FuncOp func = orig_func;
      // If there are potentially references to this func from within the
      // module, create a wrapper around it and decorate the wrapper with the
@ -3210,7 +3156,6 @@ StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
  specs.prune_unused_nodes = true;
  mlir::OwningModuleRef module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;

  const auto& graphdef = saved_model->meta_graph_def().graph_def();
  PopulateTfVersions(module.get(), graphdef.versions());
@ -3228,10 +3173,8 @@ StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph));

  NameUniquifier function_name_uniquifier(graph.flib_def());
  SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs,
                                         module.get(), &tf_name_to_mlir_name,
                                         &function_name_uniquifier);
                                         module.get());

  TF_RETURN_IF_ERROR(importer.PrepareConvert(graph));

@ -3265,8 +3208,7 @@ StatusOr<mlir::OwningModuleRef> SavedModelObjectGraphImporter::Convert(

  // Construct the SavedModel IR.
  TF_RETURN_IF_ERROR(CreateSavedModelIR(object_names, module.get(),
                                        object_graph, tf_name_to_mlir_name,
                                        saved_model));
                                        object_graph, saved_model));
  assert(mlir::succeeded(mlir::verify(module.get())));

  return module;
@ -3510,8 +3452,9 @@ StatusOr<mlir::OwningModuleRef> SavedModelSignatureDefImporter::ConvertGraph(
  specs.control_outputs = control_outputs;

  // Convert sub-graph to MLIR module.
  return GraphDefImporter::Convert(module_->getContext(), graph(), debug_info(),
                                   graph().flib_def(), specs, name);
  return GraphDefImporter::ConvertGraphAndLibraryFunctions(
      module_->getContext(), graph(), debug_info(), graph().flib_def(), specs,
      name);
}

Status SavedModelSignatureDefImporter::ConvertSignature(
@ -3636,7 +3579,8 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
        const_cast<FunctionLibraryDefinition*>(&flib_def),
        specs.restrict_functionalization_to_tpu_nodes));
  }
  return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs,
  return GraphDefImporter::ConvertGraphAndLibraryFunctions(
      context, graph, debug_info, flib_def, specs,
      /*func_name=*/"main");
}

@ -3651,11 +3595,16 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(),
                                             &flib_def, &fbody));

  // Create a copy of the function library that doesn't contain the function
  // `name` because it duplicates the `fbody` graph.
  FunctionLibraryDefinition flib_def_copy = flib_def;
  TF_RETURN_IF_ERROR(flib_def_copy.RemoveFunction(name.str()));

  tensorflow::GraphDebugInfo dummy_debug_info;
  tensorflow::GraphImportConfig specs;
  specs.graph_as_function = true;
  return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
                                   flib_def, specs, name);
  return GraphDefImporter::ConvertGraphAndLibraryFunctions(
      context, *fbody->graph, dummy_debug_info, flib_def_copy, specs, name);
}

StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
@ -18,6 +18,10 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/IR/Verifier.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
@ -52,12 +56,36 @@ static Status Export(mlir::OwningModuleRef module,
                     const GraphOptimizationPassOptions& options,
                     std::unique_ptr<Graph>* graph) {
  GraphExportConfig confs;
  return ConvertMlirToGraph(*module, confs, graph, options.flib_def);
  FunctionLibraryDefinition exported_function_library(OpRegistry::Global(), {});
  TF_RETURN_IF_ERROR(
      ConvertMlirToGraph(*module, confs, graph, &exported_function_library));
  return options.flib_def->AddLibrary(exported_function_library);
}

static Status Roundtrip(const GraphOptimizationPassOptions& options,
                        std::unique_ptr<Graph>* graph, MLIRContext* context) {
  TF_ASSIGN_OR_RETURN(auto module, Import(options, **graph, context));

  {
    // The TF runtime doesn't like an optimization pipeline
    // to change library functions in-place, i.e. create different functions
    // that have the same names as the functions in the original function
    // library. Some of these constraints come from the fact that Session can
    // extend its function library with the output function library of the
    // bridge, and equality checks of FunctionDefs are based on exact contents,
    // which is not guaranteed by the TF importer/exporter.
    //
    // Therefore, we rename all these functions to new names to avoid any
    // failures in Session::Extend.
    mlir::PassManager pm(context);
    pm.addPass(mlir::createSymbolDCEPass());
    pm.addPass(mlir::TF::CreateRenamePrivateFunctionPass());

    mlir::StatusScopedDiagnosticHandler status_handler(context);
    if (mlir::failed(pm.run(module.get()))) {
      return status_handler.ConsumeStatus();
    }
  }
  return Export(std::move(module), options, graph);
}