- Eliminate all uses of passes that mark function visibility since the visibility

is now set correctly when importing.
- Update tf_saved_model dialect verification to verify that exported functions are
  marked public.
- Eliminate function_visibility.mlir test. This test fails after the
  tf_saved_model verification changes since its run tf.entry_function based
  visibility on a tf_saved_model MLIR module. Also, these passes will be removed.
- Fix TPURewritePass to mark the appropriate visibility on the serialized MLIR
  attached to tf._TPUCompileMlir op.

PiperOrigin-RevId: 317165278
Change-Id: I8e8f6de4b56e89c303815edc3b34bcf0a4e82d2d
This commit is contained in:
Rahul Joshi 2020-06-18 13:09:04 -07:00 committed by TensorFlower Gardener
parent 543d7c47a0
commit 41e7392f58
14 changed files with 59 additions and 360 deletions

View File

@ -212,9 +212,6 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
// Saved model pass to mark global tensors immutable. // Saved model pass to mark global tensors immutable.
pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
// Used to mark non-exported functions in saved model private.
pm.addPass(mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
// Op fusion pass. // Op fusion pass.
pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());

View File

@ -491,7 +491,6 @@ cc_library(
"transforms/graph_pruning.cc", "transforms/graph_pruning.cc",
"transforms/launch_to_device_attribute.cc", "transforms/launch_to_device_attribute.cc",
"transforms/layout_optimization.cc", "transforms/layout_optimization.cc",
"transforms/mark_function_visibility.cc",
"transforms/materialize_mlir_passthrough_op.cc", "transforms/materialize_mlir_passthrough_op.cc",
"transforms/optimize.cc", "transforms/optimize.cc",
"transforms/optimize_global_tensors.cc", "transforms/optimize_global_tensors.cc",

View File

@ -229,8 +229,20 @@ static LogicalResult VerifySavedModelModule(
} }
} }
for (auto func : module.getOps<FuncOp>()) { for (auto func : module.getOps<FuncOp>()) {
const bool is_exported = IsExported(func);
if (is_exported && func.getVisibility() != FuncOp::Visibility::Public) {
return func.emitError()
<< "exported function @" << func.getName() << " should be public";
}
if (!is_exported && func.getVisibility() == FuncOp::Visibility::Public) {
return func.emitError() << "non-exported function @" << func.getName()
<< " should be private";
}
if (HasAnyTfSavedModelArgAttr(func)) { if (HasAnyTfSavedModelArgAttr(func)) {
if (!IsExported(func)) { if (!is_exported) {
return func.emitError() return func.emitError()
<< "can only apply 'tf_saved_model' argument attributes " << "can only apply 'tf_saved_model' argument attributes "
"to exported functions"; "to exported functions";

View File

@ -1,47 +0,0 @@
// RUN: tf-opt -tf-saved-model-mark-func-visibility -split-input-file %s | FileCheck --check-prefix=SAVEDMODEL %s
// RUN: tf-opt -tf-mark-func-visibility -split-input-file -verify-diagnostics %s | FileCheck %s
module attributes {tf_saved_model.semantics} {
// SAVEDMODEL: func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]}
func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]} {
"tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
return
}
// SAVEDMODEL: func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]}
func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]} {
"tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
return
}
// SAVEDMODEL: func @func_not_exported() attributes {sym_visibility = "private"}
func @func_not_exported() {
return
}
}
// -----
module {
// CHECK: func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}}
func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}} {
return %arg0 : tensor<1xi32>
}
// CHECK: func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> attributes {sym_visibility = "private"}
func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
%0 = "tf.AddV2"(%arg0, %arg1) {T = i32, device = ""} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
}
}
// -----
module {
// expected-error @+1 {{can't overwrite the visibility of function private_func_with_entry_spec with private visibility}}
func @private_func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}, sym_visibility = "private"} {
return %arg0 : tensor<1xi32>
}
}

View File

@ -1,96 +0,0 @@
// RUN: tf-opt -tf-saved-model-mark-func-visibility -symbol-dce -split-input-file %s | FileCheck %s
module attributes {tf_saved_model.semantics} {
// Test case: Unused function should be deleted.
// CHECK-NOT: func @unused
func @unused() {
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Root calls child. Child should not be deleted.
// CHECK: func @root
func @root() attributes {tf_saved_model.exported_names = ["root"]} {
"tf.some_call"() { callee = @child } : () -> ()
return
}
// CHECK: func @child
func @child() {
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Don't crash if attribute that doesn't reference a func.
"tf.some_opaque_global_variable"() { sym_name = "some_global" } : () -> ()
func @root2() attributes {tf_saved_model.exported_names = ["root2"]} {
"tf.do_something_with_a_global"() { global = @some_global } : () -> ()
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Delete recursively dead cycle.
// CHECK-NOT: func @recursively_dead0
func @recursively_dead0() {
"tf.some_call"() { callee = @recursively_dead1 } : () -> ()
return
}
// CHECK-NOT: func @recursively_dead1
func @recursively_dead1() {
"tf.some_call"() { callee = @recursively_dead0 } : () -> ()
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Root calls child with a deeply nested symbol reference.
// Child should not be deleted.
// CHECK: func @root
func @root() attributes {tf_saved_model.exported_names = ["root"]} {
"tf.some_call"() {callee = {callee = {callee = @child}}} : () -> ()
return
}
// CHECK: func @child
func @child() {
return
}
}
// -----
// Test case: If the module doesn't have tf_saved_model semantics, then this
// pass shouldn't do anything.
module {
// CHECK: func @not_dead()
func @not_dead() {
return
}
}

View File

@ -64,7 +64,7 @@ module attributes {tf_saved_model.semantics} {
return return
} }
func @f_callee(%arg0: tensor<!tf.resource<tensor<f32>>>) { func @f_callee(%arg0: tensor<!tf.resource<tensor<f32>>>) attributes {sym_visibility = "private"} {
return return
} }
} }

View File

@ -40,7 +40,7 @@ module attributes {tf_saved_model.semantics} {
return %arg0 : tensor<f32> return %arg0 : tensor<f32>
} }
func @f() { func @f() attributes {sym_visibility = "private"} {
return return
} }

View File

@ -3,7 +3,7 @@
module attributes {tf_saved_model.semantics} { module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{unknown tf_saved_model dialect arg attribute 'tf_saved_model.not_a_real_arg_attr'}} // expected-error@+1 {{unknown tf_saved_model dialect arg attribute 'tf_saved_model.not_a_real_arg_attr'}}
func @f(%arg0: tensor<f32> {tf_saved_model.not_a_real_arg_attr = 1 : i32}) { func @f(%arg0: tensor<f32> {tf_saved_model.not_a_real_arg_attr = 1 : i32}) attributes {sym_visibility = "private"} {
return return
} }
@ -233,7 +233,7 @@ module attributes {tf_saved_model.semantics} {
"tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> () "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor<?xf32>, value = dense<1.> : tensor<1xf32> } : () -> ()
// expected-error@+1 {{can only apply 'tf_saved_model' argument attributes to exported functions}} // expected-error@+1 {{can only apply 'tf_saved_model' argument attributes to exported functions}}
func @f(%arg0: tensor<!tf.resource<tensor<?xf32>>> {tf_saved_model.bound_input = @v}) func @f(%arg0: tensor<!tf.resource<tensor<?xf32>>> {tf_saved_model.bound_input = @v})
-> (tensor<?xf32> {tf_saved_model.index_path = []}) { -> (tensor<?xf32> {tf_saved_model.index_path = []}) attributes {sym_visibility = "private"} {
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<?xf32>>>) -> tensor<?xf32> %0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<?xf32>>>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
@ -273,7 +273,7 @@ module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function should have no output}} // expected-error@+1 {{the initializer function should have no output}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> () "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() -> tensor<1xf32> { func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32> %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return %0 : tensor<1xf32> return %0 : tensor<1xf32>
} }
@ -286,8 +286,34 @@ module attributes {tf_saved_model.semantics} {
"tf_saved_model.session_initializer"() { initializer = @init } : () -> () "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
// expected-error@+1 {{there must be no more than one session_initializer op}} // expected-error@+1 {{there must be no more than one session_initializer op}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> () "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
func @init() -> tensor<1xf32> { func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32> %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return %0 : tensor<1xf32> return %0 : tensor<1xf32>
} }
} }
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{exported function @f should be public}}
func @f(
%arg0: tensor<f32> {tf.resource_name = "resource"}
) attributes { sym_visibility = "private", tf_saved_model.exported_names = ["foo.some_func"] } {
return
}
}
// -----
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{non-exported function @f should be private}}
func @f(
%arg0: tensor<f32> {tf.resource_name = "resource"}
) {
return
}
}

View File

@ -20,12 +20,12 @@ module attributes {tf_saved_model.semantics} {
return %val : tensor<f32> return %val : tensor<f32>
} }
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>) %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32> return %val : tensor<f32>
} }
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32> %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
return %val : tensor<f32> return %val : tensor<f32>
} }
@ -59,7 +59,7 @@ module attributes {tf_saved_model.semantics} {
return %val : tensor<f32> return %val : tensor<f32>
} }
func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32> %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor<f32>
return %val : tensor<f32> return %val : tensor<f32>
} }
@ -85,7 +85,7 @@ module attributes {tf_saved_model.semantics} {
return %val_2 : tensor<f32> return %val_2 : tensor<f32>
} }
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%cst_1 = constant dense<2.0> : tensor<f32> %cst_1 = constant dense<2.0> : tensor<f32>
return %cst_1 : tensor<f32> return %cst_1 : tensor<f32>
} }
@ -112,13 +112,13 @@ module attributes {tf_saved_model.semantics} {
} }
// CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>) %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32> return %val : tensor<f32>
} }
// CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32> %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> () "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
return %c0 : tensor<f32> return %c0 : tensor<f32>
@ -146,13 +146,13 @@ module attributes {tf_saved_model.semantics} {
} }
// CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>) %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32> return %val : tensor<f32>
} }
// CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32> %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> () "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
return %c0 : tensor<f32> return %c0 : tensor<f32>
@ -179,13 +179,13 @@ module attributes {tf_saved_model.semantics} {
// CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor<f32>) %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32> return %val : tensor<f32>
} }
// CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> // CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor<f32>) %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32> return %val : tensor<f32>
} }
@ -212,7 +212,7 @@ module attributes {tf_saved_model.semantics} {
// CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> { func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32> %c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
"tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> () "tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
return %c0 : tensor<f32> return %c0 : tensor<f32>

View File

@ -1,165 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#define DEBUG_TYPE "tf-shape-inference"
namespace mlir {
namespace {
LogicalResult MarkFunctionVisibility(
ModuleOp module, llvm::function_ref<bool(FuncOp func)> IsExternalVisible) {
LogicalResult result = success();
for (auto func : module.getOps<FuncOp>()) {
FuncOp::Visibility old_visibility = func.getVisibility();
FuncOp::Visibility visibility = IsExternalVisible(func)
? FuncOp::Visibility::Public
: FuncOp::Visibility::Private;
auto get_visibility_name = [](FuncOp::Visibility v) {
return v == FuncOp::Visibility::Public
? "public"
: v == FuncOp::Visibility::Private ? "private" : "nested";
};
if (old_visibility != SymbolTable::Visibility::Public &&
old_visibility != visibility) {
result = func.emitError()
<< "can't overwrite the visibility of function "
<< func.getName() << " with "
<< get_visibility_name(old_visibility) << " visibility";
}
LLVM_DEBUG(llvm::dbgs()
<< "function " << func.getName() << " has "
<< get_visibility_name(visibility) << " visibility \n");
func.setVisibility(visibility);
}
return result;
}
} // anonymous namespace
namespace TF {
LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
ModuleOp module) {
auto HasEntryFunctionSpecification = [](FuncOp func) -> bool {
auto attrs = func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
return attrs && !attrs.empty();
};
return MarkFunctionVisibility(module, HasEntryFunctionSpecification);
}
namespace {
struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass
: public PassWrapper<
MarkFunctionVisibilityUsingEntryFunctionSpecificationPass,
OperationPass<ModuleOp>> {
void runOnOperation() override {
if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification(
getOperation()))) {
signalPassFailure();
}
}
};
} // namespace
static PassRegistration<
MarkFunctionVisibilityUsingEntryFunctionSpecificationPass>
pass("tf-mark-func-visibility",
"Use tf.entry_function to mark function visibility.");
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass() {
return std::make_unique<
MarkFunctionVisibilityUsingEntryFunctionSpecificationPass>();
}
// Marks the main function with public visibility, while other functions are
// marked with private visibility.
LogicalResult MarkOnlyMainFunctionWithPublicVisibility(ModuleOp module) {
for (auto func : module.getOps<FuncOp>()) {
if (func.getName() == "main") {
func.setVisibility(FuncOp::Visibility::Public);
} else {
func.setVisibility(FuncOp::Visibility::Private);
}
}
return success();
}
namespace {
struct MarkOnlyMainFunctionWithPublicVisibilityPass
: public PassWrapper<MarkOnlyMainFunctionWithPublicVisibilityPass,
OperationPass<ModuleOp>> {
void runOnOperation() override {
if (failed(MarkOnlyMainFunctionWithPublicVisibility(getOperation()))) {
signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOnlyMainFunctionWithPublicVisibilityPass() {
return std::make_unique<MarkOnlyMainFunctionWithPublicVisibilityPass>();
}
} // namespace TF
namespace tf_saved_model {
static LogicalResult MarkFunctionVisibilityUsingSavedModelLinkage(
ModuleOp module) {
if (!tf_saved_model::HasTfSavedModelSemantics(module)) {
return success();
}
return MarkFunctionVisibility(module, tf_saved_model::IsExported);
}
namespace {
struct MarkFunctionVisibilityUsingSavedModelLinkagePass
: public PassWrapper<MarkFunctionVisibilityUsingSavedModelLinkagePass,
OperationPass<ModuleOp>> {
void runOnOperation() override {
if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getOperation()))) {
signalPassFailure();
}
}
};
} // namespace
static PassRegistration<MarkFunctionVisibilityUsingSavedModelLinkagePass> pass(
"tf-saved-model-mark-func-visibility",
"Use tf_saved_model linkage information to mark function visibility.");
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass() {
return std::make_unique<MarkFunctionVisibilityUsingSavedModelLinkagePass>();
}
} // namespace tf_saved_model
} // namespace mlir

View File

@ -117,21 +117,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass();
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass(); CreateConvertReadonlyReferenceVariablesToResourceVariablesPass();
// Marks function visibility using tf.entry_function specification. That is,
// functions with tf.entry_function attributes are marked with public
// visibility while the other functions are marked with private visibility.
LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification(
ModuleOp module);
// Creates a pass that uses tf.entry_function specification to mark function
// visibility.
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass();
// Creates a pass that marks the main function with public visibility, while
// other functions are marked with private visibility.
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOnlyMainFunctionWithPublicVisibilityPass();
// Creates a simple device assignment pass on TF dialect for CoreRT use case. // Creates a simple device assignment pass on TF dialect for CoreRT use case.
std::unique_ptr<OperationPass<FuncOp>> CreateSimpleTFDeviceAssignmentPass( std::unique_ptr<OperationPass<FuncOp>> CreateSimpleTFDeviceAssignmentPass(
llvm::StringRef default_device); llvm::StringRef default_device);
@ -315,13 +300,6 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeGlobalTensorsPass();
// Creates a pass that freezes tf_saved_model.global_tensor ops. // Creates a pass that freezes tf_saved_model.global_tensor ops.
std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass(); std::unique_ptr<OperationPass<ModuleOp>> CreateFreezeGlobalTensorsPass();
// Creates a pass that uses tf_saved_model dialect linkage information
// to mark function visibility. That is, exported functions are marked with
// public visibility while the other functions are marked with private
// visibility.
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass();
} // namespace tf_saved_model } // namespace tf_saved_model
} // namespace mlir } // namespace mlir

View File

@ -146,6 +146,9 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func,
// We can simply change name of TPU program's main function because there // We can simply change name of TPU program's main function because there
// should be no other reference to it. // should be no other reference to it.
clone.setName("main"); clone.setName("main");
clone.setVisibility(FuncOp::Visibility::Public);
} else {
clone.setVisibility(FuncOp::Visibility::Private);
} }
symbol_table.insert(clone); symbol_table.insert(clone);
} }

View File

@ -267,9 +267,6 @@ Status ConvertMLIRToXlaComputation(
const XlaCompiler::ShapeRepresentationFn shape_representation_fn, const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) { std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
mlir::PassManager tf2xla(module_op.getContext()); mlir::PassManager tf2xla(module_op.getContext());
// Mark main function as public, and other functions as private.
tf2xla.addPass(
mlir::TF::CreateMarkOnlyMainFunctionWithPublicVisibilityPass());
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass()); tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass());

View File

@ -165,11 +165,6 @@ Status ConvertGraphDefToXlaViaMlir(
device_set.AddDevice(&device); device_set.AddDevice(&device);
AddDevicesToOp(*module, &device_set); AddDevicesToOp(*module, &device_set);
if (failed(mlir::TF::MarkFunctionVisibilityUsingEntryFunctionSpecification(
*module))) {
return errors::Internal("Problem with mark function visibility");
}
TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline( TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline(
*module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true)); *module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));