- Eliminate all uses of passes that mark function visibility, since the visibility is now set correctly when importing.
- Update tf_saved_model dialect verification to verify that exported functions are marked public.
- Eliminate the function_visibility.mlir test. This test fails after the tf_saved_model verification changes, since it runs tf.entry_function-based visibility on a tf_saved_model MLIR module. Also, these passes will be removed.
- Fix TPURewritePass to mark the appropriate visibility on the serialized MLIR attached to the tf._TPUCompileMlir op.

PiperOrigin-RevId: 317165278
Change-Id: I8e8f6de4b56e89c303815edc3b34bcf0a4e82d2d
parent 543d7c47a0
commit 41e7392f58
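For context, a minimal illustrative sketch (not part of this change; the function names @serve and @helper are hypothetical) of the invariant the updated tf_saved_model verifier enforces: exported functions must be public, non-exported functions must be private.

// Illustrative module that satisfies the updated tf_saved_model verifier.
module attributes {tf_saved_model.semantics} {
  // Exported: listed in tf_saved_model.exported_names, so it must be public
  // (the default when no sym_visibility attribute is present).
  func @serve() attributes {tf_saved_model.exported_names = ["serve"]} {
    "tf.some_call"() {callee = @helper} : () -> ()
    return
  }
  // Not exported: the verifier now requires sym_visibility = "private".
  func @helper() attributes {sym_visibility = "private"} {
    return
  }
}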
@@ -212,9 +212,6 @@ void CreateTFLStandardPipeline(OpPassManager& pm,

   // Saved model pass to mark global tensors immutable.
   pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
-  // Used to mark non-exported functions in saved model private.
-  pm.addPass(mlir::tf_saved_model::
-                 CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
   // Op fusion pass.
   pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());

@@ -491,7 +491,6 @@ cc_library(
         "transforms/graph_pruning.cc",
         "transforms/launch_to_device_attribute.cc",
         "transforms/layout_optimization.cc",
-        "transforms/mark_function_visibility.cc",
         "transforms/materialize_mlir_passthrough_op.cc",
         "transforms/optimize.cc",
         "transforms/optimize_global_tensors.cc",
@@ -229,8 +229,20 @@ static LogicalResult VerifySavedModelModule(
     }
   }
   for (auto func : module.getOps<FuncOp>()) {
+    const bool is_exported = IsExported(func);
+
+    if (is_exported && func.getVisibility() != FuncOp::Visibility::Public) {
+      return func.emitError()
+             << "exported function @" << func.getName() << " should be public";
+    }
+
+    if (!is_exported && func.getVisibility() == FuncOp::Visibility::Public) {
+      return func.emitError() << "non-exported function @" << func.getName()
+                              << " should be private";
+    }
+
     if (HasAnyTfSavedModelArgAttr(func)) {
-      if (!IsExported(func)) {
+      if (!is_exported) {
         return func.emitError()
                << "can only apply 'tf_saved_model' argument attributes "
                   "to exported functions";
@@ -1,47 +0,0 @@
-// RUN: tf-opt -tf-saved-model-mark-func-visibility -split-input-file %s | FileCheck --check-prefix=SAVEDMODEL %s
-// RUN: tf-opt -tf-mark-func-visibility -split-input-file -verify-diagnostics %s | FileCheck %s
-
-
-module attributes {tf_saved_model.semantics} {
-  // SAVEDMODEL: func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]}
-  func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]} {
-    "tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
-    return
-  }
-
-  // SAVEDMODEL: func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]}
-  func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]} {
-    "tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
-    return
-  }
-
-  // SAVEDMODEL: func @func_not_exported() attributes {sym_visibility = "private"}
-  func @func_not_exported() {
-    return
-  }
-
-}
-
-// -----
-
-module {
-  // CHECK: func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}}
-  func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}} {
-    return %arg0 : tensor<1xi32>
-  }
-
-  // CHECK: func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> attributes {sym_visibility = "private"}
-  func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
-    %0 = "tf.AddV2"(%arg0, %arg1) {T = i32, device = ""} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
-    return %0 : tensor<*xi32>
-  }
-}
-
-// -----
-
-module {
-  // expected-error @+1 {{can't overwrite the visibility of function private_func_with_entry_spec with private visibility}}
-  func @private_func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}, sym_visibility = "private"} {
-    return %arg0 : tensor<1xi32>
-  }
-}
@@ -1,96 +0,0 @@
-// RUN: tf-opt -tf-saved-model-mark-func-visibility -symbol-dce -split-input-file %s | FileCheck %s
-
-module attributes {tf_saved_model.semantics} {
-
-  // Test case: Unused function should be deleted.
-
-  // CHECK-NOT: func @unused
-  func @unused() {
-    return
-  }
-
-}
-
-// -----
-
-module attributes {tf_saved_model.semantics} {
-
-  // Test case: Root calls child. Child should not be deleted.
-
-  // CHECK: func @root
-  func @root() attributes {tf_saved_model.exported_names = ["root"]} {
-    "tf.some_call"() { callee = @child } : () -> ()
-    return
-  }
-
-  // CHECK: func @child
-  func @child() {
-    return
-  }
-
-}
-
-// -----
-
-module attributes {tf_saved_model.semantics} {
-
-  // Test case: Don't crash if attribute that doesn't reference a func.
-
-  "tf.some_opaque_global_variable"() { sym_name = "some_global" } : () -> ()
-
-  func @root2() attributes {tf_saved_model.exported_names = ["root2"]} {
-    "tf.do_something_with_a_global"() { global = @some_global } : () -> ()
-    return
-  }
-
-}
-
-// -----
-
-module attributes {tf_saved_model.semantics} {
-
-  // Test case: Delete recursively dead cycle.
-
-  // CHECK-NOT: func @recursively_dead0
-  func @recursively_dead0() {
-    "tf.some_call"() { callee = @recursively_dead1 } : () -> ()
-    return
-  }
-  // CHECK-NOT: func @recursively_dead1
-  func @recursively_dead1() {
-    "tf.some_call"() { callee = @recursively_dead0 } : () -> ()
-    return
-  }
-
-}
-
-// -----
-
-module attributes {tf_saved_model.semantics} {
-
-  // Test case: Root calls child with a deeply nested symbol reference.
-  // Child should not be deleted.
-
-  // CHECK: func @root
-  func @root() attributes {tf_saved_model.exported_names = ["root"]} {
-    "tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
-    return
-  }
-
-  // CHECK: func @child
-  func @child() {
-    return
-  }
-
-}
-
-// -----
-
-// Test case: If the module doesn't have tf_saved_model semantics, then this
-// pass shouldn't do anything.
-module {
-  // CHECK: func @not_dead()
-  func @not_dead() {
-    return
-  }
-}
@@ -64,7 +64,7 @@ module attributes {tf_saved_model.semantics} {
     return
   }

-  func @f_callee(%arg0: tensor<!tf.resource<tensor<f32>>>) {
+  func @f_callee(%arg0: tensor<!tf.resource<tensor<f32>>>) attributes {sym_visibility = "private"} {
     return
   }
 }
@@ -40,7 +40,7 @@ module attributes {tf_saved_model.semantics} {
     return %arg0 : tensor<f32>
   }

-  func @f() {
+  func @f() attributes {sym_visibility = "private"} {
     return
   }

@@ -3,7 +3,7 @@
 module attributes {tf_saved_model.semantics} {

   // expected-error@+1 {{unknown tf_saved_model dialect arg attribute 'tf_saved_model.not_a_real_arg_attr'}}
-  func @f(%arg0: tensor<f32> {tf_saved_model.not_a_real_arg_attr = 1 : i32}) {
+  func @f(%arg0: tensor<f32> {tf_saved_model.not_a_real_arg_attr = 1 : i32}) attributes {sym_visibility = "private"} {
     return
   }

@@ -233,7 +233,7 @@ module attributes {tf_saved_model.semantics} {
   "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
   // expected-error@+1 {{can only apply 'tf_saved_model' argument attributes to exported functions}}
   func @f(%arg0: tensor<!tf.resource<tensor<?xf32>>> {tf_saved_model.bound_input = @v})
-  -> (tensor<?xf32> {tf_saved_model.index_path = []}) {
+  -> (tensor<?xf32> {tf_saved_model.index_path = []}) attributes {sym_visibility = "private"} {
     %0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<?xf32>>>) -> tensor<?xf32>
     return %0 : tensor<?xf32>
   }
@@ -273,7 +273,7 @@ module attributes {tf_saved_model.semantics} {

   // expected-error@+1 {{the initializer function should have no output}}
   "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
-  func @init() -> tensor<1xf32> {
+  func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
     %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
     return %0 : tensor<1xf32>
   }
@@ -286,8 +286,34 @@ module attributes {tf_saved_model.semantics} {
   "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
   // expected-error@+1 {{there must be no more than one session_initializer op}}
   "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
-  func @init() -> tensor<1xf32> {
+  func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
     %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
     return %0 : tensor<1xf32>
   }
 }
+
+// -----
+
+module attributes {tf_saved_model.semantics} {
+
+  // expected-error@+1 {{exported function @f should be public}}
+  func @f(
+    %arg0: tensor<f32> {tf.resource_name = "resource"}
+  ) attributes { sym_visibility = "private", tf_saved_model.exported_names = ["foo.some_func"] } {
+    return
+  }
+
+}
+
+// -----
+
+module attributes {tf_saved_model.semantics} {
+
+  // expected-error@+1 {{non-exported function @f should be private}}
+  func @f(
+    %arg0: tensor<f32> {tf.resource_name = "resource"}
+  ) {
+    return
+  }
+
+}
@@ -20,12 +20,12 @@ module attributes {tf_saved_model.semantics} {
     return %val : tensor<f32>
   }

-  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
     return %val : tensor<f32>
   }

-  func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
     return %val : tensor<f32>
   }
@@ -59,7 +59,7 @@ module attributes {tf_saved_model.semantics} {
     return %val : tensor<f32>
   }

-  func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
     return %val : tensor<f32>
   }
@@ -85,7 +85,7 @@ module attributes {tf_saved_model.semantics} {
     return %val_2 : tensor<f32>
   }

-  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %cst_1 = constant dense<2.0> : tensor<f32>
     return %cst_1 : tensor<f32>
   }
@@ -112,13 +112,13 @@ module attributes {tf_saved_model.semantics} {
   }

   // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
-  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
     return %val : tensor<f32>
   }

   // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
-  func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
     "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
     return %c0 : tensor<f32>
@@ -146,13 +146,13 @@ module attributes {tf_saved_model.semantics} {
   }

   // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
-  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
     return %val : tensor<f32>
   }

   // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
-  func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
     "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
     return %c0 : tensor<f32>
@@ -179,13 +179,13 @@ module attributes {tf_saved_model.semantics} {


   // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
-  func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor<f32>)
     return %val : tensor<f32>
   }

   // CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
-  func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor<f32>)
     return %val : tensor<f32>
   }
@@ -212,7 +212,7 @@ module attributes {tf_saved_model.semantics} {


   // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
-  func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
+  func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
     %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
     "tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
     return %c0 : tensor<f32>
@@ -1,165 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "llvm/ADT/STLExtras.h"
-#include "mlir/IR/Module.h"  // from @llvm-project
-#include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
-
-#define DEBUG_TYPE "tf-shape-inference"
-
-namespace mlir {
-
-namespace {
-
-LogicalResult MarkFunctionVisibility(
-    ModuleOp module, llvm::function_ref<bool(FuncOp func)> IsExternalVisible) {
-  LogicalResult result = success();
-
-  for (auto func : module.getOps<FuncOp>()) {
-    FuncOp::Visibility old_visibility = func.getVisibility();
-
-    FuncOp::Visibility visibility = IsExternalVisible(func)
-                                        ? FuncOp::Visibility::Public
-                                        : FuncOp::Visibility::Private;
-
-    auto get_visibility_name = [](FuncOp::Visibility v) {
-      return v == FuncOp::Visibility::Public
-                 ? "public"
-                 : v == FuncOp::Visibility::Private ? "private" : "nested";
-    };
-
-    if (old_visibility != SymbolTable::Visibility::Public &&
-        old_visibility != visibility) {
-      result = func.emitError()
-               << "can't overwrite the visibility of function "
-               << func.getName() << " with "
-               << get_visibility_name(old_visibility) << " visibility";
-    }
-
-    LLVM_DEBUG(llvm::dbgs()
-               << "function " << func.getName() << " has "
-               << get_visibility_name(visibility) << " visibility \n");
-
-    func.setVisibility(visibility);
-  }
-
-  return result;
-}
-
-}  // anonymous namespace
-
-namespace TF {
-
-LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
-    ModuleOp module) {
-  auto HasEntryFunctionSpecification = [](FuncOp func) -> bool {
-    auto attrs = func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
-    return attrs && !attrs.empty();
-  };
-  return MarkFunctionVisibility(module, HasEntryFunctionSpecification);
-}
-
-namespace {
-struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass
-    : public PassWrapper<
-          MarkFunctionVisibilityUsingEntryFunctionSpecificationPass,
-          OperationPass<ModuleOp>> {
-  void runOnOperation() override {
-    if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification(
-            getOperation()))) {
-      signalPassFailure();
-    }
-  }
-};
-}  // namespace
-
-static PassRegistration<
-    MarkFunctionVisibilityUsingEntryFunctionSpecificationPass>
-    pass("tf-mark-func-visibility",
-         "Use tf.entry_function to mark function visibility.");
-
-std::unique_ptr<OperationPass<ModuleOp>>
-CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass() {
-  return std::make_unique<
-      MarkFunctionVisibilityUsingEntryFunctionSpecificationPass>();
-}
-
-// Marks the main function with public visibility, while other functions are
-// marked with private visibility.
-LogicalResult MarkOnlyMainFunctionWithPublicVisibility(ModuleOp module) {
-  for (auto func : module.getOps<FuncOp>()) {
-    if (func.getName() == "main") {
-      func.setVisibility(FuncOp::Visibility::Public);
-    } else {
-      func.setVisibility(FuncOp::Visibility::Private);
-    }
-  }
-  return success();
-}
-
-namespace {
-struct MarkOnlyMainFunctionWithPublicVisibilityPass
-    : public PassWrapper<MarkOnlyMainFunctionWithPublicVisibilityPass,
-                         OperationPass<ModuleOp>> {
-  void runOnOperation() override {
-    if (failed(MarkOnlyMainFunctionWithPublicVisibility(getOperation()))) {
-      signalPassFailure();
-    }
-  }
-};
-}  // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>>
-CreateMarkOnlyMainFunctionWithPublicVisibilityPass() {
-  return std::make_unique<MarkOnlyMainFunctionWithPublicVisibilityPass>();
-}
-
-}  // namespace TF
-
-namespace tf_saved_model {
-
-static LogicalResult MarkFunctionVisibilityUsingSavedModelLinkage(
-    ModuleOp module) {
-  if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
-    return success();
-  }
-  return MarkFunctionVisibility(module, tf_saved_model::IsExported);
-}
-
-namespace {
-struct MarkFunctionVisibilityUsingSavedModelLinkagePass
-    : public PassWrapper<MarkFunctionVisibilityUsingSavedModelLinkagePass,
-                         OperationPass<ModuleOp>> {
-  void runOnOperation() override {
-    if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getOperation()))) {
-      signalPassFailure();
-    }
-  }
-};
-}  // namespace
-
-static PassRegistration<MarkFunctionVisibilityUsingSavedModelLinkagePass> pass(
-    "tf-saved-model-mark-func-visibility",
-    "Use tf_saved_model linkage information to mark function visibility.");
-
-std::unique_ptr<OperationPass<ModuleOp>>
-CreateMarkFunctionVisibilityUsingSavedModelLinkagePass() {
-  return std::make_unique<MarkFunctionVisibilityUsingSavedModelLinkagePass>();
-}
-
-}  // namespace tf_saved_model
-
-}  // namespace mlir
@@ -117,21 +117,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
 std::unique_ptr<OperationPass<FuncOp>>
 CreateConvertReadonlyReferenceVariablesToResourceVariablesPass();

-// Marks function visibility using tf.entry_function specification. That is,
-// functions with tf.entry_function attributes are marked with public
-// visibility while the other functions are marked with private visibility.
-LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
-    ModuleOp module);
-// Creates a pass that uses tf.entry_function specification to mark function
-// visibility.
-std::unique_ptr<OperationPass<ModuleOp>>
-CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass();
-
-// Creates a pass that marks the main function with public visibility, while
-// other functions are marked with private visibility.
-std::unique_ptr<OperationPass<ModuleOp>>
-CreateMarkOnlyMainFunctionWithPublicVisibilityPass();
-
 // Creates a simple device assignment pass on TF dialect for CoreRT use case.
 std::unique_ptr<OperationPass<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
     llvm::StringRef default_device);
@@ -315,13 +300,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass();
 // Creates a pass that freezes tf_saved_model.global_tensor ops.
 std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass();

-// Creates a pass that uses tf_saved_model dialect linkage information
-// to mark function visibility. That is, exported functions are marked with
-// public visibility while the other functions are marked with private
-// visibility.
-std::unique_ptr<OperationPass<ModuleOp>>
-CreateMarkFunctionVisibilityUsingSavedModelLinkagePass();
-
 }  // namespace tf_saved_model

 }  // namespace mlir
@@ -146,6 +146,9 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
       // We can simply change name of TPU program's main function because there
       // should be no other reference to it.
       clone.setName("main");
+      clone.setVisibility(FuncOp::Visibility::Public);
+    } else {
+      clone.setVisibility(FuncOp::Visibility::Private);
     }
     symbol_table.insert(clone);
   }
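As a rough illustration of the TPURewritePass fix (an assumed sketch only; the function names below are hypothetical), the module serialized onto the tf._TPUCompileMlir op now carries a public main entry function while any other cloned function stays private:

// Illustrative shape of the serialized TPU program module after the fix.
module {
  // Entry function of the TPU program: renamed to "main" and marked public.
  func @main() {
    "tf.some_call"() {callee = @helper} : () -> ()
    return
  }
  // Any other function pulled into the serialized module is private.
  func @helper() attributes {sym_visibility = "private"} {
    return
  }
}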
@@ -267,9 +267,6 @@ Status ConvertMLIRToXlaComputation(
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
   mlir::PassManager tf2xla(module_op.getContext());
-  // Mark main function as public, and other functions as private.
-  tf2xla.addPass(
-      mlir::TF::CreateMarkOnlyMainFunctionWithPublicVisibilityPass());
   tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
   tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
   tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass());
@@ -165,11 +165,6 @@ Status ConvertGraphDefToXlaViaMlir(
   device_set.AddDevice(&device);
   AddDevicesToOp(*module, &device_set);

-  if (failed(mlir::TF::MarkFunctionVisibilityUsingEntryFunctionSpecification(
-          *module))) {
-    return errors::Internal("Problem with mark function visibility");
-  }
-
   TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline(
       *module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));
