diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 589515d6246..008098f62ba 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -212,9 +212,6 @@ void CreateTFLStandardPipeline(OpPassManager& pm, // Saved model pass to mark global tensors immutable. pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); - // Used to mark non-exported functions in saved model private. - pm.addPass(mlir::tf_saved_model:: - CreateMarkFunctionVisibilityUsingSavedModelLinkagePass()); // Op fusion pass. pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 17ed0e36a28..54e57512c32 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -491,7 +491,6 @@ cc_library( "transforms/graph_pruning.cc", "transforms/launch_to_device_attribute.cc", "transforms/layout_optimization.cc", - "transforms/mark_function_visibility.cc", "transforms/materialize_mlir_passthrough_op.cc", "transforms/optimize.cc", "transforms/optimize_global_tensors.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index 6af70158e14..d59532fef65 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -229,8 +229,20 @@ static LogicalResult VerifySavedModelModule( } } for (auto func : module.getOps()) { + const bool is_exported = IsExported(func); + + if (is_exported && func.getVisibility() != FuncOp::Visibility::Public) { + return func.emitError() + << "exported function @" << func.getName() << " should be public"; + } + + if (!is_exported && func.getVisibility() == FuncOp::Visibility::Public) { + return func.emitError() << "non-exported function @" << func.getName() + << " should be private"; + } + if (HasAnyTfSavedModelArgAttr(func)) { - if (!IsExported(func)) { + if (!is_exported) { return func.emitError() << "can only apply 'tf_saved_model' argument attributes " "to exported functions"; diff --git a/tensorflow/compiler/mlir/tensorflow/tests/function_visibility.mlir b/tensorflow/compiler/mlir/tensorflow/tests/function_visibility.mlir deleted file mode 100644 index 55af3cffde3..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/function_visibility.mlir +++ /dev/null @@ -1,47 +0,0 @@ -// RUN: tf-opt -tf-saved-model-mark-func-visibility -split-input-file %s | FileCheck --check-prefix=SAVEDMODEL %s -// RUN: tf-opt -tf-mark-func-visibility -split-input-file -verify-diagnostics %s | FileCheck %s - - -module attributes {tf_saved_model.semantics} { - // SAVEDMODEL: func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]} - func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]} { - "tf.some_call"() {callee = {callee = {callee = @child}}} : () -> () - return - } - - // SAVEDMODEL: func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]} - func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]} { - "tf.some_call"() {callee = {callee = {callee = @child}}} : () -> () - return - } - - // SAVEDMODEL: func @func_not_exported() attributes {sym_visibility = "private"} - func @func_not_exported() { - return - } - -} - -// ----- - -module { - // CHECK: func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}} - func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}} { - return %arg0 : tensor<1xi32> - } - - // CHECK: func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> attributes {sym_visibility = "private"} - func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { - %0 = "tf.AddV2"(%arg0, %arg1) {T = i32, device = ""} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - return %0 : tensor<*xi32> - } -} - -// ----- - -module { - // expected-error @+1 {{can't overwrite the visibility of function private_func_with_entry_spec with private visibility}} - func @private_func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}, sym_visibility = "private"} { - return %arg0 : tensor<1xi32> - } -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_delete_unused_funcs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_delete_unused_funcs.mlir deleted file mode 100644 index 6f2c47a935f..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_delete_unused_funcs.mlir +++ /dev/null @@ -1,96 +0,0 @@ -// RUN: tf-opt -tf-saved-model-mark-func-visibility -symbol-dce -split-input-file %s | FileCheck %s - -module attributes {tf_saved_model.semantics} { - - // Test case: Unused function should be deleted. - - // CHECK-NOT: func @unused - func @unused() { - return - } - -} - -// ----- - -module attributes {tf_saved_model.semantics} { - - // Test case: Root calls child. Child should not be deleted. - - // CHECK: func @root - func @root() attributes {tf_saved_model.exported_names = ["root"]} { - "tf.some_call"() { callee = @child } : () -> () - return - } - - // CHECK: func @child - func @child() { - return - } - -} - -// ----- - -module attributes {tf_saved_model.semantics} { - - // Test case: Don't crash if attribute that doesn't reference a func. - - "tf.some_opaque_global_variable"() { sym_name = "some_global" } : () -> () - - func @root2() attributes {tf_saved_model.exported_names = ["root2"]} { - "tf.do_something_with_a_global"() { global = @some_global } : () -> () - return - } - -} - -// ----- - -module attributes {tf_saved_model.semantics} { - - // Test case: Delete recursively dead cycle. - - // CHECK-NOT: func @recursively_dead0 - func @recursively_dead0() { - "tf.some_call"() { callee = @recursively_dead1 } : () -> () - return - } - // CHECK-NOT: func @recursively_dead1 - func @recursively_dead1() { - "tf.some_call"() { callee = @recursively_dead0 } : () -> () - return - } - -} - -// ----- - -module attributes {tf_saved_model.semantics} { - - // Test case: Root calls child with a deeply nested symbol reference. - // Child should not be deleted. - - // CHECK: func @root - func @root() attributes {tf_saved_model.exported_names = ["root"]} { - "tf.some_call"() {callee = {callee = {callee = @child}}} : () -> () - return - } - - // CHECK: func @child - func @child() { - return - } - -} - -// ----- - -// Test case: If the module doesn't have tf_saved_model semantics, then this -// pass shouldn't do anything. -module { - // CHECK: func @not_dead() - func @not_dead() { - return - } -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir index 38627b41b68..6c32a3bc4d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir @@ -64,7 +64,7 @@ module attributes {tf_saved_model.semantics} { return } - func @f_callee(%arg0: tensor>>) { + func @f_callee(%arg0: tensor>>) attributes {sym_visibility = "private"} { return } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir index aa1f996da07..05e7638645f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir @@ -40,7 +40,7 @@ module attributes {tf_saved_model.semantics} { return %arg0 : tensor } - func @f() { + func @f() attributes {sym_visibility = "private"} { return } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir index 544600cf6b8..f04e1a60b36 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir @@ -3,7 +3,7 @@ module attributes {tf_saved_model.semantics} { // expected-error@+1 {{unknown tf_saved_model dialect arg attribute 'tf_saved_model.not_a_real_arg_attr'}} - func @f(%arg0: tensor {tf_saved_model.not_a_real_arg_attr = 1 : i32}) { + func @f(%arg0: tensor {tf_saved_model.not_a_real_arg_attr = 1 : i32}) attributes {sym_visibility = "private"} { return } @@ -233,7 +233,7 @@ module attributes {tf_saved_model.semantics} { "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<1.> : tensor<1xf32> } : () -> () // expected-error@+1 {{can only apply 'tf_saved_model' argument attributes to exported functions}} func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) - -> (tensor {tf_saved_model.index_path = []}) { + -> (tensor {tf_saved_model.index_path = []}) attributes {sym_visibility = "private"} { %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor return %0 : tensor } @@ -273,7 +273,7 @@ module attributes {tf_saved_model.semantics} { // expected-error@+1 {{the initializer function should have no output}} "tf_saved_model.session_initializer"() { initializer = @init } : () -> () - func @init() -> tensor<1xf32> { + func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} { %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32> return %0 : tensor<1xf32> } @@ -286,8 +286,34 @@ module attributes {tf_saved_model.semantics} { "tf_saved_model.session_initializer"() { initializer = @init } : () -> () // expected-error@+1 {{there must be no more than one session_initializer op}} "tf_saved_model.session_initializer"() { initializer = @init } : () -> () - func @init() -> tensor<1xf32> { + func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} { %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32> return %0 : tensor<1xf32> } } + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{exported function @f should be public}} + func @f( + %arg0: tensor {tf.resource_name = "resource"} + ) attributes { sym_visibility = "private", tf_saved_model.exported_names = ["foo.some_func"] } { + return + } + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{non-exported function @f should be private}} + func @f( + %arg0: tensor {tf.resource_name = "resource"} + ) { + return + } + +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir index 91e8c9c4b66..14a0006cd3b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir @@ -20,12 +20,12 @@ module attributes {tf_saved_model.semantics} { return %val : tensor } - func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor) return %val : tensor } - func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor return %val : tensor } @@ -59,7 +59,7 @@ module attributes {tf_saved_model.semantics} { return %val : tensor } - func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor return %val : tensor } @@ -85,7 +85,7 @@ module attributes {tf_saved_model.semantics} { return %val_2 : tensor } - func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %cst_1 = constant dense<2.0> : tensor return %cst_1 : tensor } @@ -112,13 +112,13 @@ module attributes {tf_saved_model.semantics} { } // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor - func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor) return %val : tensor } // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor - func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %c0 = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor) -> () return %c0 : tensor @@ -146,13 +146,13 @@ module attributes {tf_saved_model.semantics} { } // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor - func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor) return %val : tensor } // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor - func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %c0 = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor) -> () return %c0 : tensor @@ -179,13 +179,13 @@ module attributes {tf_saved_model.semantics} { // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor - func @f(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor) return %val : tensor } // CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor - func @g(%arg0: tensor<*x!tf.resource>) -> tensor { + func @g(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor) return %val : tensor } @@ -212,7 +212,7 @@ module attributes {tf_saved_model.semantics} { // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor - func @f(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %c0 = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor "tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor) -> () return %c0 : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc deleted file mode 100644 index 31a80a4ecdb..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc +++ /dev/null @@ -1,165 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" - -#define DEBUG_TYPE "tf-shape-inference" - -namespace mlir { - -namespace { - -LogicalResult MarkFunctionVisibility( - ModuleOp module, llvm::function_ref IsExternalVisible) { - LogicalResult result = success(); - - for (auto func : module.getOps()) { - FuncOp::Visibility old_visibility = func.getVisibility(); - - FuncOp::Visibility visibility = IsExternalVisible(func) - ? FuncOp::Visibility::Public - : FuncOp::Visibility::Private; - - auto get_visibility_name = [](FuncOp::Visibility v) { - return v == FuncOp::Visibility::Public - ? "public" - : v == FuncOp::Visibility::Private ? "private" : "nested"; - }; - - if (old_visibility != SymbolTable::Visibility::Public && - old_visibility != visibility) { - result = func.emitError() - << "can't overwrite the visibility of function " - << func.getName() << " with " - << get_visibility_name(old_visibility) << " visibility"; - } - - LLVM_DEBUG(llvm::dbgs() - << "function " << func.getName() << " has " - << get_visibility_name(visibility) << " visibility \n"); - - func.setVisibility(visibility); - } - - return result; -} - -} // anonymous namespace - -namespace TF { - -LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification( - ModuleOp module) { - auto HasEntryFunctionSpecification = [](FuncOp func) -> bool { - auto attrs = func.getAttrOfType("tf.entry_function"); - return attrs && !attrs.empty(); - }; - return MarkFunctionVisibility(module, HasEntryFunctionSpecification); -} - -namespace { -struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass - : public PassWrapper< - MarkFunctionVisibilityUsingEntryFunctionSpecificationPass, - OperationPass> { - void runOnOperation() override { - if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification( - getOperation()))) { - signalPassFailure(); - } - } -}; -} // namespace - -static PassRegistration< - MarkFunctionVisibilityUsingEntryFunctionSpecificationPass> - pass("tf-mark-func-visibility", - "Use tf.entry_function to mark function visibility."); - -std::unique_ptr> -CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass() { - return std::make_unique< - MarkFunctionVisibilityUsingEntryFunctionSpecificationPass>(); -} - -// Marks the main function with public visibility, while other functions are -// marked with private visibility. -LogicalResult MarkOnlyMainFunctionWithPublicVisibility(ModuleOp module) { - for (auto func : module.getOps()) { - if (func.getName() == "main") { - func.setVisibility(FuncOp::Visibility::Public); - } else { - func.setVisibility(FuncOp::Visibility::Private); - } - } - return success(); -} - -namespace { -struct MarkOnlyMainFunctionWithPublicVisibilityPass - : public PassWrapper> { - void runOnOperation() override { - if (failed(MarkOnlyMainFunctionWithPublicVisibility(getOperation()))) { - signalPassFailure(); - } - } -}; -} // namespace - -std::unique_ptr> -CreateMarkOnlyMainFunctionWithPublicVisibilityPass() { - return std::make_unique(); -} - -} // namespace TF - -namespace tf_saved_model { - -static LogicalResult MarkFunctionVisibilityUsingSavedModelLinkage( - ModuleOp module) { - if (!tf_saved_model::HasTfSavedModelSemantics(module)) { - return success(); - } - return MarkFunctionVisibility(module, tf_saved_model::IsExported); -} - -namespace { -struct MarkFunctionVisibilityUsingSavedModelLinkagePass - : public PassWrapper> { - void runOnOperation() override { - if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getOperation()))) { - signalPassFailure(); - } - } -}; -} // namespace - -static PassRegistration pass( - "tf-saved-model-mark-func-visibility", - "Use tf_saved_model linkage information to mark function visibility."); - -std::unique_ptr> -CreateMarkFunctionVisibilityUsingSavedModelLinkagePass() { - return std::make_unique(); -} - -} // namespace tf_saved_model - -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 7158d0f6be0..5cb15027fc5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -117,21 +117,6 @@ std::unique_ptr> CreatePromoteVarHandlesToArgsPass(); std::unique_ptr> CreateConvertReadonlyReferenceVariablesToResourceVariablesPass(); -// Marks function visibility using tf.entry_function specification. That is, -// functions with tf.entry_function attributes are marked with public -// visibility while the other functions are marked with private visibility. -LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification( - ModuleOp module); -// Creates a pass that uses tf.entry_function specification to mark function -// visibility. -std::unique_ptr> -CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass(); - -// Creates a pass that marks the main function with public visibility, while -// other functions are marked with private visibility. -std::unique_ptr> -CreateMarkOnlyMainFunctionWithPublicVisibilityPass(); - // Creates a simple device assignment pass on TF dialect for CoreRT use case. std::unique_ptr> CreateSimpleTFDeviceAssignmentPass( llvm::StringRef default_device); @@ -315,13 +300,6 @@ std::unique_ptr> CreateOptimizeGlobalTensorsPass(); // Creates a pass that freezes tf_saved_model.global_tensor ops. std::unique_ptr> CreateFreezeGlobalTensorsPass(); -// Creates a pass that uses tf_saved_model dialect linkage information -// to mark function visibility. That is, exported functions are marked with -// public visibility while the other functions are marked with private -// visibility. -std::unique_ptr> -CreateMarkFunctionVisibilityUsingSavedModelLinkagePass(); - } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 696882cd105..ec9b3df525f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -146,6 +146,9 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, // We can simply change name of TPU program's main function because there // should be no other reference to it. clone.setName("main"); + clone.setVisibility(FuncOp::Visibility::Public); + } else { + clone.setVisibility(FuncOp::Visibility::Private); } symbol_table.insert(clone); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index fd1ba3b1901..dac2fea87e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -267,9 +267,6 @@ Status ConvertMLIRToXlaComputation( const XlaCompiler::ShapeRepresentationFn shape_representation_fn, std::vector> custom_legalization_passes) { mlir::PassManager tf2xla(module_op.getContext()); - // Mark main function as public, and other functions as private. - tf2xla.addPass( - mlir::TF::CreateMarkOnlyMainFunctionWithPublicVisibilityPass()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass()); diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 43793be56a7..60d1f3da0c5 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -165,11 +165,6 @@ Status ConvertGraphDefToXlaViaMlir( device_set.AddDevice(&device); AddDevicesToOp(*module, &device_set); - if (failed(mlir::TF::MarkFunctionVisibilityUsingEntryFunctionSpecification( - *module))) { - return errors::Internal("Problem with mark function visibility"); - } - TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline( *module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true));