Add MergeControlFlow pass, partially implemented.
Determines whether IfRegions can safely be merged together and merges them. PiperOrigin-RevId: 346162444 Change-Id: I79481e0d1b1fc2b56e4158efc3d7874f8c22b879
This commit is contained in:
parent
295b3dcc3d
commit
07bb67303b
@ -889,6 +889,7 @@ cc_library(
|
||||
"transforms/layout_optimization.cc",
|
||||
"transforms/mark_ops_for_outside_compilation.cc",
|
||||
"transforms/materialize_mlir_passthrough_op.cc",
|
||||
"transforms/merge_control_flow.cc",
|
||||
"transforms/optimize.cc",
|
||||
"transforms/outside_compiled_to_host_launch.cc",
|
||||
"transforms/parallel_execute_to_islands.cc",
|
||||
|
@ -0,0 +1,175 @@
|
||||
// RUN: tf-opt %s -tf-merge-control-flow | FileCheck %s
|
||||
|
||||
// Two IfRegions that are driven by different predicates must stay separate.

// CHECK-LABEL: func @different_predicate_no_merge
func @different_predicate_no_merge() {
  // CHECK:      tf_device.cluster
  // CHECK:        "tf.IfRegion"
  // CHECK:        "tf.IfRegion"
  "tf_device.cluster"() ( {
    %pred_true = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    %pred_false = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
    "tf.IfRegion"(%pred_true) ( {
      %a = "tf.A"() : () -> (tensor<f32>)
      "tf.Yield"() : () -> ()
      }, {
      "tf.Yield"() : () -> ()
     }) {is_stateless = true} : (tensor<i1>) -> ()
    "tf.IfRegion"(%pred_false) ( {
      %b = "tf.B"() : () -> (tensor<f32>)
      "tf.Yield"() : () -> ()
      }, {
      "tf.Yield"() : () -> ()
     }) {is_stateless = true} : (tensor<i1>) -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
|
||||
|
||||
// Two IfRegions that share a predicate but live in different blocks (here the
// cond and body regions of a WhileRegion) must stay separate.

// CHECK-LABEL: func @different_block_no_merge
func @different_block_no_merge() {
  // CHECK:      tf_device.cluster
  // CHECK:        "tf.IfRegion"
  // CHECK:        "tf.IfRegion"
  "tf_device.cluster"() ( {
    %pred = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    %stop = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
    %data = "tf.A"() : () -> (tensor<?xf32>)
    %count = "tf.B"() : () -> (tensor<i32>)
    "tf.WhileRegion"(%count, %data) ({
      ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
        "tf.IfRegion"(%pred) ( {
          %a = "tf.A"() : () -> (tensor<f32>)
          "tf.Yield"() : () -> ()
          }, {
          "tf.Yield"() : () -> ()
         }) {is_stateless = true} : (tensor<i1>) -> ()
        "tf.Yield"(%stop) : (tensor<i1>) -> ()
    }, {
      ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
        "tf.IfRegion"(%pred) ( {
          %b = "tf.B"() : () -> (tensor<f32>)
          "tf.Yield"() : () -> ()
          }, {
          "tf.Yield"() : () -> ()
         }) {is_stateless = true} : (tensor<i1>) -> ()
        "tf.Yield"(%arg1, %arg2) : (tensor<i32>, tensor<?xf32>) -> ()
    }) {is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
|
||||
|
||||
// Two result-less IfRegions in the same block that share a predicate are
// folded into a single IfRegion.

// CHECK-LABEL: func @same_predicate_no_returns_merged
func @same_predicate_no_returns_merged() {
  // CHECK:      tf_device.cluster
  // CHECK:        "tf.IfRegion"
  // CHECK-NOT:    "tf.IfRegion"
  "tf_device.cluster"() ( {
    %pred = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    "tf.IfRegion"(%pred) ( {
      %a = "tf.A"() : () -> (tensor<f32>)
      "tf.Yield"() : () -> ()
      }, {
      "tf.Yield"() : () -> ()
     }) {is_stateless = true} : (tensor<i1>) -> ()
    "tf.IfRegion"(%pred) ( {
      %b = "tf.B"() : () -> (tensor<f32>)
      "tf.Yield"() : () -> ()
      }, {
      "tf.Yield"() : () -> ()
     }) {is_stateless = true} : (tensor<i1>) -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
|
||||
|
||||
// IfRegions that share a predicate are not merged when the second one
// consumes a value computed, via intermediate ops, from the first one's result.

// CHECK-LABEL: func @same_predicate_intermediate_dependency_no_merge
func @same_predicate_intermediate_dependency_no_merge() {
  // CHECK:      tf_device.cluster
  // CHECK:        "tf.IfRegion"
  // CHECK:        "tf.IfRegion"
  "tf_device.cluster"() ( {
    %pred = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    %first = "tf.IfRegion"(%pred) ( {
      %a = "tf.A"() : () -> (tensor<f32>)
      "tf.Yield"(%a) : (tensor<f32>) -> ()
      }, {
      %c = "tf.C"() : () -> (tensor<f32>)
      "tf.Yield"(%c) : (tensor<f32>) -> ()
     }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
    %d = "tf.D"(%first) : (tensor<f32>) -> (tensor<f32>)
    %e = "tf.E"(%d) : (tensor<f32>) -> (tensor<f32>)
    "tf.IfRegion"(%pred) ( {
      %b = "tf.B"(%e) : (tensor<f32>) -> (tensor<f32>)
      "tf.Yield"() : () -> ()
      }, {
      "tf.Yield"() : () -> ()
     }) {is_stateless = true} : (tensor<i1>) -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
|
||||
|
||||
// IfRegions that share a predicate are not merged when an op sitting between
// them introduces a side effect dependency.

// CHECK-LABEL: func @same_predicate_side_effect_dependency_no_merge
func @same_predicate_side_effect_dependency_no_merge() {
  // CHECK:      tf_device.cluster
  // CHECK:        "tf.IfRegion"
  // CHECK:        "tf.IfRegion"
  "tf_device.cluster"() ( {
    %pred = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    %first = "tf.IfRegion"(%pred) ( {
      %a = "tf.A"() : () -> (tensor<f32>)
      "tf.Yield"(%a) : (tensor<f32>) -> ()
      }, {
      %c = "tf.C"() : () -> (tensor<f32>)
      "tf.Yield"(%c) : (tensor<f32>) -> ()
     }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
    "tf.D"(%first) : (tensor<f32>) -> ()
    "tf.IfRegion"(%pred) ( {
      %b = "tf.B"(%first) : (tensor<f32>) -> (tensor<f32>)
      "tf.Yield"() : () -> ()
      }, {
      "tf.Yield"() : () -> ()
     }) {is_stateless = false} : (tensor<i1>) -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
|
||||
|
||||
// Merging a stateless IfRegion with a stateful one must produce an IfRegion
// whose is_stateless attribute is false.

// CHECK-LABEL: func @same_predicate_stateless_merge
func @same_predicate_stateless_merge() {
  // CHECK:      tf_device.cluster
  // CHECK:        "tf.IfRegion"
  // CHECK:        is_stateless = false
  // CHECK-NOT:    "tf.IfRegion"
  "tf_device.cluster"() ( {
    %pred = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    %first = "tf.IfRegion"(%pred) ( {
      %a = "tf.A"() : () -> (tensor<f32>)
      "tf.Yield"(%a) : (tensor<f32>) -> ()
      }, {
      %c = "tf.C"() : () -> (tensor<f32>)
      "tf.Yield"(%c) : (tensor<f32>) -> ()
     }) {is_stateless = true} : (tensor<i1>) -> (tensor<f32>)
    "tf.IfRegion"(%pred) ( {
      %b = "tf.B"() : () -> (tensor<f32>)
      "tf.Yield"() : () -> ()
      }, {
      "tf.Yield"() : () -> ()
     }) {is_stateless = false} : (tensor<i1>) -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
|
@ -0,0 +1,230 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFDevice {
|
||||
|
||||
namespace {
|
||||
|
||||
// This pass merges IfRegion ops together if they have the same predicate and it
|
||||
// is safe to do so (there are no intermediate dependencies, they are in the
|
||||
// same block, etc).
|
||||
//
|
||||
// A simple example:
|
||||
// "tf.IfRegion"(%0) ( {
|
||||
// %2 = "tf.A"() : () -> (tensor<f32>)
|
||||
// "tf.Yield"() : () -> ()
|
||||
// }, {
|
||||
// "tf.Yield"() : () -> ()
|
||||
// }) { is_stateless = true } : (tensor<i1>) -> ()
|
||||
// "tf.IfRegion"(%0) ( {
|
||||
// %2 = "tf.B"() : () -> (tensor<f32>)
|
||||
// "tf.Yield"() : () -> ()
|
||||
// }, {
|
||||
// "tf.Yield"() : () -> ()
|
||||
// }) { is_stateless = true } : (tensor<i1>) -> ()
|
||||
// Would become:
|
||||
// "tf.IfRegion"(%0) ( {
|
||||
// %2 = "tf.A"() : () -> (tensor<f32>)
|
||||
// %3 = "tf.B"() : () -> (tensor<f32>)
|
||||
// "tf.Yield"() : () -> ()
|
||||
// }, {
|
||||
// "tf.Yield"() : () -> ()
|
||||
// }) { is_stateless = true } : (tensor<i1>) -> ()
|
||||
|
||||
// Pass definition. It is built on the aggregate-analysis consumer base so that
// `runOnFunction` is handed the TF side effect analysis info together with
// each function it processes.
struct MergeControlFlow
    : public TF::PerFunctionAggregateAnalysisConsumerPass<
          MergeControlFlow, TF::SideEffectAnalysis> {
  // Merges compatible IfRegion ops inside the `tf_device.cluster` ops of
  // `func`, using `side_effect_analysis` to detect ordering constraints.
  void runOnFunction(
      FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis);
};
|
||||
|
||||
// Returns whether it is safe to merge `source` IfRegion into `destination`
|
||||
// IfRegion. `source` must come after `destination`.
|
||||
bool SafeToMerge(TF::IfRegionOp source, TF::IfRegionOp destination,
|
||||
const TF::SideEffectAnalysis::Info& side_effect_analysis) {
|
||||
// IfRegion ops must be in the same block.
|
||||
if (source.getOperation()->getBlock() !=
|
||||
destination.getOperation()->getBlock())
|
||||
return false;
|
||||
assert(destination.getOperation()->isBeforeInBlock(source.getOperation()));
|
||||
|
||||
llvm::SmallSetVector<Operation*, 4> source_ops;
|
||||
source_ops.insert(source);
|
||||
for (Operation& op : source.then_branch().front()) {
|
||||
source_ops.insert(&op);
|
||||
}
|
||||
for (Operation& op : source.else_branch().front()) {
|
||||
source_ops.insert(&op);
|
||||
}
|
||||
|
||||
// If there is an intermediate data or side effect dependency between the
|
||||
// ops in destination and the ops in the source, it's not safe to merge
|
||||
// them.
|
||||
llvm::SmallSetVector<Operation*, 4> op_stack;
|
||||
for (auto* user : destination.getOperation()->getUsers()) {
|
||||
if (!source_ops.contains(user)) op_stack.insert(user);
|
||||
}
|
||||
for (Operation& op : destination.then_branch().front()) {
|
||||
for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
|
||||
if (!source_ops.contains(successor)) op_stack.insert(successor);
|
||||
}
|
||||
}
|
||||
for (Operation& op : destination.else_branch().front()) {
|
||||
for (auto* successor : side_effect_analysis.DirectControlSuccessors(&op)) {
|
||||
if (!source_ops.contains(successor)) op_stack.insert(successor);
|
||||
}
|
||||
}
|
||||
|
||||
bool safe_to_merge = true;
|
||||
|
||||
while (!op_stack.empty()) {
|
||||
auto* next_op = op_stack.pop_back_val();
|
||||
for (auto* user : next_op->getUsers()) {
|
||||
if (source_ops.contains(user)) {
|
||||
safe_to_merge = false;
|
||||
break;
|
||||
} else {
|
||||
op_stack.insert(user);
|
||||
}
|
||||
}
|
||||
for (auto* successor :
|
||||
side_effect_analysis.DirectControlSuccessors(next_op)) {
|
||||
if (source_ops.contains(successor)) {
|
||||
safe_to_merge = false;
|
||||
break;
|
||||
} else {
|
||||
op_stack.insert(successor);
|
||||
}
|
||||
}
|
||||
if (!safe_to_merge) break;
|
||||
}
|
||||
return safe_to_merge;
|
||||
}
|
||||
|
||||
// Move the body excluding the terminators of else and then regions from
|
||||
// 'source' to 'destination'.
|
||||
void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) {
|
||||
Block& destination_then_block = destination.then_branch().front();
|
||||
auto& source_then_body = source.then_branch().front().getOperations();
|
||||
destination_then_block.getOperations().splice(
|
||||
destination_then_block.without_terminator().end(), source_then_body,
|
||||
source_then_body.begin(), std::prev(source_then_body.end()));
|
||||
|
||||
Block& destination_else_block = destination.else_branch().front();
|
||||
auto& source_else_body = source.else_branch().front().getOperations();
|
||||
destination_else_block.getOperations().splice(
|
||||
destination_else_block.without_terminator().end(), source_else_body,
|
||||
source_else_body.begin(), std::prev(source_else_body.end()));
|
||||
}
|
||||
|
||||
// Returns the op used as the insertion point for the merged IfRegion. For now
// that is always `source`; `destination` is not consulted.
Operation* GetIfInsertionPoint(TF::IfRegionOp source,
                               TF::IfRegionOp destination) {
  // TODO(b/173422484): Pick this insertion point better.
  Operation* insertion_point = source.getOperation();
  return insertion_point;
}
|
||||
|
||||
// Replaces `destination` and `source` with one new IfRegion on `destination`'s
// predicate. The new op has no results; its branches hold the ops of
// `destination` followed by the ops of `source`. Both original ops are erased
// and the new op is returned.
TF::IfRegionOp CreateMergedIf(TF::IfRegionOp source,
                              TF::IfRegionOp destination) {
  // Result merging is not done here: the result list and both yield operand
  // lists are left empty.
  llvm::SmallVector<Type, 4> result_types;
  llvm::SmallVector<Value, 4> then_yield_operands;
  llvm::SmallVector<Value, 4> else_yield_operands;

  OpBuilder builder(destination);
  builder.setInsertionPoint(GetIfInsertionPoint(source, destination));

  // The merged op is stateless only if both inputs are.
  const bool merged_is_stateless =
      destination.is_stateless() && source.is_stateless();
  auto merged_if = builder.create<TF::IfRegionOp>(
      destination.getLoc(), result_types, destination.cond(),
      merged_is_stateless);

  // Give each branch a block holding just a yield, reusing the location of
  // the corresponding terminator in `destination`.
  Block* then_block = new Block;
  merged_if.then_branch().push_back(then_block);
  Block* else_block = new Block;
  merged_if.else_branch().push_back(else_block);

  builder.setInsertionPointToEnd(then_block);
  builder.create<TF::YieldOp>(
      destination.then_branch().front().getTerminator()->getLoc(),
      /*operands=*/then_yield_operands);

  builder.setInsertionPointToEnd(else_block);
  builder.create<TF::YieldOp>(
      destination.else_branch().front().getTerminator()->getLoc(),
      /*operands=*/else_yield_operands);

  // Drain `destination` first so its ops precede those of `source`.
  MoveBranches(/*source=*/destination, /*destination=*/merged_if);
  destination.erase();
  MoveBranches(/*source=*/source, /*destination=*/merged_if);
  source.erase();

  return merged_if;
}
|
||||
|
||||
// Groups if regions by common predicate and attemps to merge them.
|
||||
void OptimizeIfRegions(
|
||||
Block* block, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
|
||||
// Determine IfRegions with the same predicate.
|
||||
llvm::SmallDenseMap<Value, llvm::SmallVector<TF::IfRegionOp, 8>, 8>
|
||||
grouped_if_ops;
|
||||
block->walk([&](TF::IfRegionOp if_op) {
|
||||
auto it = grouped_if_ops.try_emplace(if_op.cond());
|
||||
it.first->getSecond().push_back(if_op);
|
||||
});
|
||||
|
||||
for (const auto& entry : grouped_if_ops) {
|
||||
llvm::ArrayRef<TF::IfRegionOp> if_ops = entry.second;
|
||||
TF::IfRegionOp first_if_op = if_ops[0];
|
||||
for (int i = 1; i < if_ops.size(); ++i) {
|
||||
TF::IfRegionOp if_op = if_ops[i];
|
||||
if (!SafeToMerge(if_op, first_if_op, side_effect_analysis)) break;
|
||||
|
||||
auto new_if_op = CreateMergedIf(if_op, first_if_op);
|
||||
|
||||
first_if_op = new_if_op;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Applies the IfRegion merge to the body of every `tf_device.cluster` op found
// in `func`, using `side_effect_analysis` for the safety checks.
void MergeControlFlow::runOnFunction(
    FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) {
  // The callback never interrupts the walk, so there is no failure state to
  // report afterwards.
  func.walk([&](tf_device::ClusterOp cluster) {
    OptimizeIfRegions(&cluster.GetBody(), side_effect_analysis);
  });
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass() {
|
||||
return std::make_unique<MergeControlFlow>();
|
||||
}
|
||||
|
||||
static PassRegistration<MergeControlFlow> pass(
|
||||
"tf-merge-control-flow", "Merges control flow with a common predicate.");
|
||||
} // namespace TFDevice
|
||||
} // namespace mlir
|
@ -282,6 +282,9 @@ CreateAnnotateParameterReplicationPass();
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
CreateMarkOpsForOutsideCompilationPass();
|
||||
|
||||
// Creates a pass that merges control flow with similar predicates.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateMergeControlFlowPass();
|
||||
|
||||
// Creates a pass that hoists a `tf_device.launch` body and assigns a `device`
|
||||
// attribute to each TensorFlow dialect op in the body based on the `device`
|
||||
// attribute on the `tf_device.launch`.
|
||||
|
Loading…
x
Reference in New Issue
Block a user