[MLIR] Determine function visibility during import
- Modify GraphDef and TFLite flatbuffer importers to determine the visibility of functions during the import - Fixed OptimizeFunctionalOpsPass just optimize functional If ops and not erase any functions. Add another SymbolDCE Pass after OptimizeFunctionalOps pass to handle deleting any dead functions. - This eliminates the need to mark visibility using the saved model linkage in the Tf -> TFLite pass pipeline. PiperOrigin-RevId: 316145906 Change-Id: I3eb698cef7ed96e11ee2f1f24c698986fcba02c2
This commit is contained in:
parent
9401f80281
commit
14b5803e26
tensorflow/compiler/mlir
@ -868,6 +868,8 @@ StatusOr<FuncOp> ConvertSubgraph(
|
||||
subgraph, &builder, "outputs", func_outputs));
|
||||
}
|
||||
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
|
||||
} else {
|
||||
func.setVisibility(FuncOp::Visibility::Private);
|
||||
}
|
||||
|
||||
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
|
||||
|
@ -13,14 +13,12 @@ func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
|
||||
return %3 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-NOT: add
|
||||
// CHECK-NOT: sub
|
||||
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
@ -42,65 +40,31 @@ func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
|
||||
return %3 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-NOT: addormul
|
||||
// CHECK-NOT: sub
|
||||
// CHECK-NOT: mul
|
||||
// CHECK-NOT: add
|
||||
|
||||
func @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
func @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
|
||||
%0 = constant dense<false> : tensor<i1>
|
||||
%1 = "tf.If"(%0, %arg1, %arg0) {else_branch = @mul, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
func @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.Multiply"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verify that branch functions with multiple references are not erased.
|
||||
|
||||
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>, tensor<f32>) {
|
||||
%0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
|
||||
%1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
|
||||
%2 = constant dense<true> : tensor<i1>
|
||||
|
||||
// CHECK: tf.Add
|
||||
%3 = "tf.If"(%2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
|
||||
// CHECK: tf.If
|
||||
%4 = "tf.If"(%arg2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %3, %4 : tensor<f32>, tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK: add
|
||||
// CHECK: sub
|
||||
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Verify unused if with functions without side-effects are removed.
|
||||
|
||||
// Verify unused if with functions without side-effects is removed.
|
||||
// CHECK-LABEL: main
|
||||
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
|
||||
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
|
||||
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
|
||||
@ -118,26 +82,22 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
|
||||
return %4 : tensor<3x15x14x8xf32>
|
||||
}
|
||||
|
||||
func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
|
||||
func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
%cst = constant dense<false> : tensor<i1>
|
||||
return %cst : tensor<i1>
|
||||
}
|
||||
|
||||
func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
|
||||
func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
%cst = constant dense<true> : tensor<i1>
|
||||
return %cst : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK-NOT: tf.If
|
||||
// CHECK: return
|
||||
// CHECK-NOT: func @_functionalize_if_else_branch_00
|
||||
// CHECK-NOT: func @_functionalize_if_then_branch_00
|
||||
|
||||
// -----
|
||||
|
||||
// Verify unused if with function with side-effects is not removed.
|
||||
|
||||
// CHECK-LABEL: main
|
||||
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
|
||||
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
|
||||
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
|
||||
@ -155,27 +115,25 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
|
||||
return %4 : tensor<3x15x14x8xf32>
|
||||
}
|
||||
|
||||
func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
|
||||
func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
%cst = constant dense<false> : tensor<i1>
|
||||
return %cst : tensor<i1>
|
||||
}
|
||||
|
||||
func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
|
||||
func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.blah"() : () -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK: tf.If
|
||||
// CHECK: return
|
||||
// CHECK: func @_functionalize_if_else_branch_01
|
||||
// CHECK: func @_functionalize_if_then_branch_01
|
||||
|
||||
// -----
|
||||
|
||||
// Verify unused if with function with side-effects is removed if op says
|
||||
// stateless.
|
||||
|
||||
// CHECK-LABEL: main
|
||||
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
|
||||
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
|
||||
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
|
||||
@ -193,18 +151,15 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
|
||||
return %4 : tensor<3x15x14x8xf32>
|
||||
}
|
||||
|
||||
func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
|
||||
func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
%cst = constant dense<false> : tensor<i1>
|
||||
return %cst : tensor<i1>
|
||||
}
|
||||
|
||||
func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
|
||||
func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
|
||||
%0 = "tf.blah"() : () -> tensor<i1>
|
||||
return %0 : tensor<i1>
|
||||
}
|
||||
|
||||
// CHECK: func @main
|
||||
// CHECK-NOT: tf.If
|
||||
// CHECK: return
|
||||
// CHECK-NOT: func @_functionalize_if_else_branch_02
|
||||
// CHECK-NOT: func @_functionalize_if_then_branch_02
|
||||
|
@ -94,12 +94,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
|
||||
}
|
||||
|
||||
// This pass marks non-exported functions as symbol visibility 'private'
|
||||
// those deemed read-only as immutable.
|
||||
pass_manager->addPass(
|
||||
mlir::tf_saved_model::
|
||||
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
|
||||
|
||||
pass_manager->addPass(mlir::createInlinerPass());
|
||||
pass_manager->addPass(mlir::createSymbolDCEPass());
|
||||
|
||||
@ -162,6 +156,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
// so that it can target constants introduced once TensorFlow Identity ops
|
||||
// are removed during legalization.
|
||||
pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
|
||||
pass_manager->addPass(mlir::createSymbolDCEPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
|
||||
// This pass should be always at the end of the floating point model
|
||||
@ -237,6 +232,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
||||
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
|
||||
pm.addPass(mlir::TFL::CreateOptimizePass());
|
||||
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
|
||||
pm.addPass(mlir::createSymbolDCEPass());
|
||||
|
||||
// Canonicalize, CSE etc.
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
|
@ -32,8 +32,6 @@ namespace mlir {
|
||||
namespace TFL {
|
||||
namespace {
|
||||
|
||||
using FuncSet = llvm::SmallSet<FuncOp, 4>;
|
||||
|
||||
// Module pass to optimize TensorFlow functional ops.
|
||||
struct OptimizeFunctionalOpsPass
|
||||
: public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
|
||||
@ -44,8 +42,8 @@ struct OptimizeFunctionalOpsPass
|
||||
// op operands' types.
|
||||
//
|
||||
// Requires the function has exactly one block.
|
||||
static void UpdateFuncType(FuncOp func) {
|
||||
Operation* terminator = &func.getBlocks().front().back();
|
||||
void UpdateFuncType(FuncOp func) {
|
||||
Operation* terminator = func.front().getTerminator();
|
||||
auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());
|
||||
|
||||
FunctionType func_type = func.getType();
|
||||
@ -57,7 +55,7 @@ static void UpdateFuncType(FuncOp func) {
|
||||
}
|
||||
|
||||
// TODO(jpienaar): Remove when recursive side-effect modeling is added.
|
||||
static bool IsSideEffectFree(FuncOp func) {
|
||||
bool IsSideEffectFree(FuncOp func) {
|
||||
return !func.getBody()
|
||||
.walk([&](Operation* op) {
|
||||
if (!MemoryEffectOpInterface::hasNoEffect(op) &&
|
||||
@ -72,8 +70,8 @@ static bool IsSideEffectFree(FuncOp func) {
|
||||
// function body based on the conditional value.
|
||||
class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
public:
|
||||
explicit FoldIfOp(MLIRContext* context, FuncSet* inlined_funcs)
|
||||
: OpRewritePattern<TF::IfOp>(context), inlined_funcs_(inlined_funcs) {}
|
||||
explicit FoldIfOp(MLIRContext* context)
|
||||
: OpRewritePattern<TF::IfOp>(context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(TF::IfOp op,
|
||||
PatternRewriter& rewriter) const override {
|
||||
@ -82,7 +80,7 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
// updated if operands' shapes change after inlining. Without this
|
||||
// restriction, it would require tensor cast ops.
|
||||
FuncOp parent_op = op.getParentOfType<FuncOp>();
|
||||
if (parent_op.getBlocks().size() != 1) return failure();
|
||||
if (!llvm::hasSingleElement(parent_op)) return failure();
|
||||
|
||||
// Find the then and else branch functions.
|
||||
SymbolTable table(op.getParentOfType<ModuleOp>());
|
||||
@ -95,8 +93,6 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
if (op.use_empty() &&
|
||||
(op.is_stateless() ||
|
||||
(IsSideEffectFree(then_branch) && IsSideEffectFree(else_branch)))) {
|
||||
inlined_funcs_->insert(then_branch);
|
||||
inlined_funcs_->insert(else_branch);
|
||||
rewriter.eraseOp(op.getOperation());
|
||||
return success();
|
||||
}
|
||||
@ -118,14 +114,14 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
// Make sure that the function has exactly one block to simplify inlining.
|
||||
// TFLite doesn't use control flow with blocks so functions with more than
|
||||
// one blocks are not encountered in practice.
|
||||
if (func.getBody().getBlocks().size() != 1) return failure();
|
||||
if (!llvm::hasSingleElement(func)) return failure();
|
||||
|
||||
BlockAndValueMapping mapper;
|
||||
for (int i = 0, e = func.getNumArguments(); i != e; ++i)
|
||||
mapper.map(func.getArgument(i), op.getOperand(i + 1));
|
||||
|
||||
llvm::SmallVector<Value, 4> updated_results;
|
||||
for (auto& op_to_inline : func.getBody().front()) {
|
||||
for (auto& op_to_inline : func.front()) {
|
||||
// If this is a terminator, identify the values to use to replace the
|
||||
// original If op.
|
||||
if (op_to_inline.isKnownTerminator()) {
|
||||
@ -145,64 +141,26 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
// return type should be updated.
|
||||
UpdateFuncType(parent_op);
|
||||
|
||||
// Track functions that could be erased if this op was the last reference
|
||||
// of the function.
|
||||
inlined_funcs_->insert(then_branch);
|
||||
inlined_funcs_->insert(else_branch);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
FuncSet* inlined_funcs_;
|
||||
};
|
||||
|
||||
// Erases functions from the given candidates that are not referenced by any of
|
||||
// the ops in the module.
|
||||
static void EraseDeadFuncs(const FuncSet& candidate_funcs, ModuleOp module) {
|
||||
if (candidate_funcs.empty()) return;
|
||||
|
||||
SymbolTable manager(module);
|
||||
|
||||
// Identify the functions that are used as symbols in the module and shouldn't
|
||||
// be erased.
|
||||
FuncSet in_use_funcs;
|
||||
manager.getOp()->walk([&](Operation* op) {
|
||||
for (auto attr : op->getAttrs()) {
|
||||
if (auto symbol = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
|
||||
auto func = manager.lookup<FuncOp>(symbol.getValue());
|
||||
in_use_funcs.insert(func);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
for (FuncOp func : candidate_funcs) {
|
||||
if (!in_use_funcs.count(func)) manager.erase(func);
|
||||
}
|
||||
}
|
||||
|
||||
void OptimizeFunctionalOpsPass::runOnOperation() {
|
||||
OwningRewritePatternList patterns;
|
||||
|
||||
FuncSet inlined_funcs;
|
||||
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
|
||||
patterns.insert<FoldIfOp>(&getContext());
|
||||
|
||||
ModuleOp module = getOperation();
|
||||
applyPatternsAndFoldGreedily(module, patterns);
|
||||
|
||||
// Erase inlined functions that don't have any references.
|
||||
//
|
||||
// TODO(hinsu): Update this to not erase entry points once TFLite support to
|
||||
// have multiple entry points is implemented. Until then, it is safe to
|
||||
// erase these functions.
|
||||
EraseDeadFuncs(inlined_funcs, module);
|
||||
}
|
||||
|
||||
PassRegistration<OptimizeFunctionalOpsPass> pass(
|
||||
"tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
|
||||
return std::make_unique<OptimizeFunctionalOpsPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<OptimizeFunctionalOpsPass> pass(
|
||||
"tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
||||
|
@ -49,5 +49,5 @@ library {
|
||||
}
|
||||
}
|
||||
|
||||
# CHECK-DAG: func @custom_relu{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.relu, {}>}
|
||||
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}
|
||||
# CHECK-DAG: func @custom_relu{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.relu, {}>}
|
||||
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}
|
||||
|
@ -124,5 +124,5 @@ versions {
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110}
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111}
|
||||
|
||||
# CHECK-LABEL: func @foo110() {
|
||||
# CHECK-LABEL: func @foo111() {
|
||||
# CHECK-LABEL: func @foo110() attributes {sym_visibility = "private"}
|
||||
# CHECK-LABEL: func @foo111() attributes {sym_visibility = "private"}
|
||||
|
@ -57,7 +57,7 @@ versions {
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, f = @foo0}
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
|
||||
|
||||
# CHECK-LABEL: func @foo0() {
|
||||
# CHECK-LABEL: func @foo0() attributes {sym_visibility = "private"}
|
||||
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
|
||||
|
||||
# CHECK-LABEL: func @bar0() {
|
||||
# CHECK-LABEL: func @bar0() attributes {sym_visibility = "private"}
|
||||
|
@ -219,8 +219,7 @@ class ImporterBase {
|
||||
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
|
||||
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
|
||||
const absl::InlinedVector<Node*, 4>& control_ret_nodes,
|
||||
llvm::ArrayRef<mlir::NamedAttribute> attrs,
|
||||
bool function_graph);
|
||||
llvm::ArrayRef<mlir::NamedAttribute> attrs);
|
||||
|
||||
// Finds out the function definition for the given function name from the
|
||||
// graph and converts it to a function of the module. This method is called
|
||||
@ -1302,8 +1301,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
|
||||
|
||||
TF_RETURN_IF_ERROR(child_importer.Convert(
|
||||
mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
|
||||
llvm::makeArrayRef(attributes.begin(), attributes.end()),
|
||||
/*function_graph=*/true));
|
||||
llvm::makeArrayRef(attributes.begin(), attributes.end())));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1405,7 +1403,7 @@ Status ImporterBase::Convert(
|
||||
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
|
||||
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
|
||||
const absl::InlinedVector<Node*, 4>& control_ret_nodes,
|
||||
llvm::ArrayRef<mlir::NamedAttribute> attrs, bool function_graph) {
|
||||
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
|
||||
// TODO(b/122040776): Uses debug info for FunctionDef.
|
||||
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
|
||||
func_name, func_type, attrs);
|
||||
@ -2222,8 +2220,15 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
|
||||
PopulateTfVersions(module.get(), graph.versions());
|
||||
|
||||
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
|
||||
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs,
|
||||
specs.graph_as_function));
|
||||
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));
|
||||
|
||||
// Mark main function public, others private.
|
||||
for (auto function : module.get().getOps<mlir::FuncOp>()) {
|
||||
auto visibility = function.getName() == func_name
|
||||
? mlir::FuncOp::Visibility::Public
|
||||
: mlir::FuncOp::Visibility::Private;
|
||||
function.setVisibility(visibility);
|
||||
}
|
||||
return module;
|
||||
}
|
||||
|
||||
@ -2888,6 +2893,16 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
|
||||
}
|
||||
}
|
||||
|
||||
// Marks the visibility of functions in the saved model module.
|
||||
void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
|
||||
for (auto func : module.getOps<mlir::FuncOp>()) {
|
||||
auto visibility = mlir::tf_saved_model::IsExported(func)
|
||||
? mlir::FuncOp::Visibility::Public
|
||||
: mlir::FuncOp::Visibility::Private;
|
||||
func.setVisibility(visibility);
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder the ops in the module to make testing easier and less dependent
|
||||
// on implementation details such as the order of functions in the
|
||||
// FunctionDefLibrary.
|
||||
@ -3130,6 +3145,7 @@ Status CreateSavedModelIR(
|
||||
AdjustBoundInputArgTypes(module);
|
||||
module.setAttr("tf_saved_model.semantics", builder.getUnitAttr());
|
||||
SortSavedModelModule(module);
|
||||
MarkSavedModelFunctionVisibility(module);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -3299,6 +3315,7 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
|
||||
mlir::OpBuilder builder(module_->getBodyRegion());
|
||||
module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
|
||||
SortSavedModelModule(*module_);
|
||||
MarkSavedModelFunctionVisibility(*module_);
|
||||
|
||||
return std::move(module_);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user