[MLIR] Determine function visibility during import

- Modify the GraphDef and TFLite flatbuffer importers to determine the
  visibility of functions during import (see the sketch after this list).
- Fix OptimizeFunctionalOpsPass to only optimize functional If ops and not
  erase any functions. Add another SymbolDCE pass after the
  OptimizeFunctionalOps pass to handle deleting any dead functions.
- This eliminates the need to mark visibility using the saved model linkage in
  the TF -> TFLite pass pipeline.
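
A minimal sketch (hypothetical function names, not taken from the tests below)
of the imported IR after this change: the entry function stays public, while
every other imported function carries private visibility so that SymbolDCE can
later delete it once it becomes unused.

func @main(%arg0: tensor<f32>) -> tensor<f32>
    attributes {tf.entry_function = {inputs = "input", outputs = "output"}} {
  // Entry function: no sym_visibility attribute, i.e. public.
  %0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
func @library_fn(%arg0: tensor<f32>) -> tensor<f32>
    attributes {sym_visibility = "private"} {
  // Imported library function: marked private during import.
  %0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}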

PiperOrigin-RevId: 316145906
Change-Id: I3eb698cef7ed96e11ee2f1f24c698986fcba02c2
Rahul Joshi 2020-06-12 11:45:55 -07:00 committed by TensorFlower Gardener
parent 9401f80281
commit 14b5803e26
8 changed files with 62 additions and 134 deletions

@@ -868,6 +868,8 @@ StatusOr<FuncOp> ConvertSubgraph(
subgraph, &builder, "outputs", func_outputs));
}
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
} else {
func.setVisibility(FuncOp::Visibility::Private);
}
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
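
As a rough sketch (hypothetical subgraph names) of what this hunk produces for
a flatbuffer with more than one subgraph: only the entry subgraph keeps public
visibility and the tf.entry_function attribute; every other subgraph, e.g. a
while-loop condition, is imported as a private function.

func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32>
    attributes {tf.entry_function = {inputs = "a,b", outputs = "out"}} {
  %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}
func @while_cond(%arg0: tensor<1xf32>) -> tensor<i1>
    attributes {sym_visibility = "private"} {
  %cst = constant dense<true> : tensor<i1>
  return %cst : tensor<i1>
}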

@@ -13,14 +13,12 @@ func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
return %3 : tensor<f32>
}
// CHECK-NOT: add
// CHECK-NOT: sub
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@@ -42,65 +40,31 @@ func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
return %3 : tensor<f32>
}
// CHECK-NOT: addormul
// CHECK-NOT: sub
// CHECK-NOT: mul
// CHECK-NOT: add
func @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = constant dense<false> : tensor<i1>
%1 = "tf.If"(%0, %arg1, %arg0) {else_branch = @mul, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
}
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Multiply"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Verify that branch functions with multiple references are not erased.
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>, tensor<f32>) {
%0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
%1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
%2 = constant dense<true> : tensor<i1>
// CHECK: tf.Add
%3 = "tf.If"(%2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: tf.If
%4 = "tf.If"(%arg2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %3, %4 : tensor<f32>, tensor<f32>
}
// CHECK: add
// CHECK: sub
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Verify unused if with functions without side-effects are removed.
// Verify unused if with functions without side-effects is removed.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
@@ -118,26 +82,22 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
return %4 : tensor<3x15x14x8xf32>
}
func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1>
}
func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<true> : tensor<i1>
return %cst : tensor<i1>
}
// CHECK: func @main
// CHECK-NOT: tf.If
// CHECK: return
// CHECK-NOT: func @_functionalize_if_else_branch_00
// CHECK-NOT: func @_functionalize_if_then_branch_00
// -----
// Verify unused if with function with side-effects is not removed.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
@@ -155,27 +115,25 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
return %4 : tensor<3x15x14x8xf32>
}
func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1>
}
func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}
// CHECK: func @main
// CHECK: tf.If
// CHECK: return
// CHECK: func @_functionalize_if_else_branch_01
// CHECK: func @_functionalize_if_then_branch_01
// -----
// Verify unused if with function with side-effects is removed if op says
// stateless.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
@@ -193,18 +151,15 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
return %4 : tensor<3x15x14x8xf32>
}
func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1>
}
func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}
// CHECK: func @main
// CHECK-NOT: tf.If
// CHECK: return
// CHECK-NOT: func @_functionalize_if_else_branch_02
// CHECK-NOT: func @_functionalize_if_then_branch_02
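
For context, a minimal sketch (not part of the test file) of the division of
labor the updated checks rely on: -tfl-optimize-functional-ops folds the tf.If
ops but leaves the branch functions behind, and a follow-up -symbol-dce erases
only those functions that are private and no longer referenced.

// Left behind after folding; private and unreferenced, so SymbolDCE erases it.
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
    attributes {sym_visibility = "private"} {
  %0 = "tf.Add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
// A public function such as @main is always kept, and so is any private
// function it still references (for example through a remaining tf.If).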

@@ -94,12 +94,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
}
// This pass marks non-exported functions as symbol visibility 'private'
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
pass_manager->addPass(mlir::createInlinerPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
@@ -162,6 +156,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// so that it can target constants introduced once TensorFlow Identity ops
// are removed during legalization.
pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
// This pass should be always at the end of the floating point model
@@ -237,6 +232,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
pm.addPass(mlir::TFL::CreateOptimizePass());
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
pm.addPass(mlir::createSymbolDCEPass());
// Canonicalize, CSE etc.
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());

@@ -32,8 +32,6 @@ namespace mlir {
namespace TFL {
namespace {
using FuncSet = llvm::SmallSet<FuncOp, 4>;
// Module pass to optimize TensorFlow functional ops.
struct OptimizeFunctionalOpsPass
: public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
@@ -44,8 +42,8 @@ struct OptimizeFunctionalOpsPass
// op operands' types.
//
// Requires the function has exactly one block.
static void UpdateFuncType(FuncOp func) {
Operation* terminator = &func.getBlocks().front().back();
void UpdateFuncType(FuncOp func) {
Operation* terminator = func.front().getTerminator();
auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());
FunctionType func_type = func.getType();
@@ -57,7 +55,7 @@ static void UpdateFuncType(FuncOp func) {
}
// TODO(jpienaar): Remove when recursive side-effect modeling is added.
static bool IsSideEffectFree(FuncOp func) {
bool IsSideEffectFree(FuncOp func) {
return !func.getBody()
.walk([&](Operation* op) {
if (!MemoryEffectOpInterface::hasNoEffect(op) &&
@@ -72,8 +70,8 @@ static bool IsSideEffectFree(FuncOp func) {
// function body based on the conditional value.
class FoldIfOp : public OpRewritePattern<TF::IfOp> {
public:
explicit FoldIfOp(MLIRContext* context, FuncSet* inlined_funcs)
: OpRewritePattern<TF::IfOp>(context), inlined_funcs_(inlined_funcs) {}
explicit FoldIfOp(MLIRContext* context)
: OpRewritePattern<TF::IfOp>(context) {}
LogicalResult matchAndRewrite(TF::IfOp op,
PatternRewriter& rewriter) const override {
@@ -82,7 +80,7 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// updated if operands' shapes change after inlining. Without this
// restriction, it would require tensor cast ops.
FuncOp parent_op = op.getParentOfType<FuncOp>();
if (parent_op.getBlocks().size() != 1) return failure();
if (!llvm::hasSingleElement(parent_op)) return failure();
// Find the then and else branch functions.
SymbolTable table(op.getParentOfType<ModuleOp>());
@@ -95,8 +93,6 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
if (op.use_empty() &&
(op.is_stateless() ||
(IsSideEffectFree(then_branch) && IsSideEffectFree(else_branch)))) {
inlined_funcs_->insert(then_branch);
inlined_funcs_->insert(else_branch);
rewriter.eraseOp(op.getOperation());
return success();
}
@@ -118,14 +114,14 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// Make sure that the function has exactly one block to simplify inlining.
// TFLite doesn't use control flow with blocks so functions with more than
// one blocks are not encountered in practice.
if (func.getBody().getBlocks().size() != 1) return failure();
if (!llvm::hasSingleElement(func)) return failure();
BlockAndValueMapping mapper;
for (int i = 0, e = func.getNumArguments(); i != e; ++i)
mapper.map(func.getArgument(i), op.getOperand(i + 1));
llvm::SmallVector<Value, 4> updated_results;
for (auto& op_to_inline : func.getBody().front()) {
for (auto& op_to_inline : func.front()) {
// If this is a terminator, identify the values to use to replace the
// original If op.
if (op_to_inline.isKnownTerminator()) {
@@ -145,64 +141,26 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// return type should be updated.
UpdateFuncType(parent_op);
// Track functions that could be erased if this op was the last reference
// of the function.
inlined_funcs_->insert(then_branch);
inlined_funcs_->insert(else_branch);
return success();
}
private:
FuncSet* inlined_funcs_;
};
// Erases functions from the given candidates that are not referenced by any of
// the ops in the module.
static void EraseDeadFuncs(const FuncSet& candidate_funcs, ModuleOp module) {
if (candidate_funcs.empty()) return;
SymbolTable manager(module);
// Identify the functions that are used as symbols in the module and shouldn't
// be erased.
FuncSet in_use_funcs;
manager.getOp()->walk([&](Operation* op) {
for (auto attr : op->getAttrs()) {
if (auto symbol = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
auto func = manager.lookup<FuncOp>(symbol.getValue());
in_use_funcs.insert(func);
}
}
});
for (FuncOp func : candidate_funcs) {
if (!in_use_funcs.count(func)) manager.erase(func);
}
}
void OptimizeFunctionalOpsPass::runOnOperation() {
OwningRewritePatternList patterns;
FuncSet inlined_funcs;
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
patterns.insert<FoldIfOp>(&getContext());
ModuleOp module = getOperation();
applyPatternsAndFoldGreedily(module, patterns);
// Erase inlined functions that don't have any references.
//
// TODO(hinsu): Update this to not erase entry points once TFLite support to
// have multiple entry points is implemented. Until then, it is safe to
// erase these functions.
EraseDeadFuncs(inlined_funcs, module);
}
PassRegistration<OptimizeFunctionalOpsPass> pass(
"tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
return std::make_unique<OptimizeFunctionalOpsPass>();
}
static PassRegistration<OptimizeFunctionalOpsPass> pass(
"tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
} // namespace TFL
} // namespace mlir
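
A sketch on hypothetical IR of the FoldIfOp rewrite that remains in this pass:
a stateless tf.If with a constant condition is replaced by an inlined copy of
the taken branch, and erasing the now-dead branch functions is left entirely to
the SymbolDCE pass added to the pipelines above.

// Before:
func @caller(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
    attributes {sym_visibility = "private"} {
  %cst = constant dense<true> : tensor<i1>
  %0 = "tf.If"(%cst, %arg0, %arg1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
// After (the then branch @add was inlined; @add and @sub are left untouched):
func @caller(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
    attributes {sym_visibility = "private"} {
  %0 = "tf.Add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}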

@@ -49,5 +49,5 @@ library {
}
}
# CHECK-DAG: func @custom_relu{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.relu, {}>}
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}
# CHECK-DAG: func @custom_relu{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.relu, {}>}
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}

@@ -124,5 +124,5 @@ versions {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111}
# CHECK-LABEL: func @foo110() {
# CHECK-LABEL: func @foo111() {
# CHECK-LABEL: func @foo110() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo111() attributes {sym_visibility = "private"}
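
For reference, a hypothetical imported library function now prints with an
explicit visibility attribute, which is what the updated CHECK-LABEL patterns
above match:

func @foo110() attributes {sym_visibility = "private"} {
  // Hypothetical body; the real test functions' contents are unchanged.
  return
}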

@@ -57,7 +57,7 @@ versions {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @foo0() {
# CHECK-LABEL: func @foo0() attributes {sym_visibility = "private"}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @bar0() {
# CHECK-LABEL: func @bar0() attributes {sym_visibility = "private"}

@@ -219,8 +219,7 @@ class ImporterBase {
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs,
bool function_graph);
llvm::ArrayRef<mlir::NamedAttribute> attrs);
// Finds out the function definition for the given function name from the
// graph and converts it to a function of the module. This method is called
@@ -1302,8 +1301,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
TF_RETURN_IF_ERROR(child_importer.Convert(
mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
llvm::makeArrayRef(attributes.begin(), attributes.end()),
/*function_graph=*/true));
llvm::makeArrayRef(attributes.begin(), attributes.end())));
return Status::OK();
}
@@ -1405,7 +1403,7 @@ Status ImporterBase::Convert(
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs, bool function_graph) {
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// TODO(b/122040776): Uses debug info for FunctionDef.
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
func_name, func_type, attrs);
@@ -2222,8 +2220,15 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
PopulateTfVersions(module.get(), graph.versions());
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs,
specs.graph_as_function));
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));
// Mark main function public, others private.
for (auto function : module.get().getOps<mlir::FuncOp>()) {
auto visibility = function.getName() == func_name
? mlir::FuncOp::Visibility::Public
: mlir::FuncOp::Visibility::Private;
function.setVisibility(visibility);
}
return module;
}
@@ -2888,6 +2893,16 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
}
}
// Marks the visibility of functions in the saved model module.
void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
for (auto func : module.getOps<mlir::FuncOp>()) {
auto visibility = mlir::tf_saved_model::IsExported(func)
? mlir::FuncOp::Visibility::Public
: mlir::FuncOp::Visibility::Private;
func.setVisibility(visibility);
}
}
// Reorder the ops in the module to make testing easier and less dependent
// on implementation details such as the order of functions in the
// FunctionDefLibrary.
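
A minimal sketch (hypothetical symbols; other tf_saved_model argument and
result attributes omitted for brevity) of the effect of
MarkSavedModelFunctionVisibility: functions carrying
tf_saved_model.exported_names stay public, everything else becomes private.

module attributes {tf_saved_model.semantics} {
  func @serving_default(%arg0: tensor<f32>) -> tensor<f32>
      attributes {tf_saved_model.exported_names = ["serving_default"]} {
    // Exported: left public.
    %0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
  func @internal_helper(%arg0: tensor<f32>) -> tensor<f32>
      attributes {sym_visibility = "private"} {
    // Not exported: marked private.
    %0 = "tf.Identity"(%arg0) : (tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
}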
@@ -3130,6 +3145,7 @@ Status CreateSavedModelIR(
AdjustBoundInputArgTypes(module);
module.setAttr("tf_saved_model.semantics", builder.getUnitAttr());
SortSavedModelModule(module);
MarkSavedModelFunctionVisibility(module);
return Status::OK();
}
@@ -3299,6 +3315,7 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
mlir::OpBuilder builder(module_->getBodyRegion());
module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
SortSavedModelModule(*module_);
MarkSavedModelFunctionVisibility(*module_);
return std::move(module_);
}