Add support for handling island merging with nested IslandOp/GraphOp in island coarsening pass.

PiperOrigin-RevId: 260783005
This commit is contained in:
Andy Ly 2019-07-30 13:28:54 -07:00 committed by TensorFlower Gardener
parent 5944bf515a
commit 88f5009af7
2 changed files with 115 additions and 35 deletions

View File

@ -228,7 +228,7 @@ func @islands_interleaved(%arg0 : tensor<i32>, %arg1 : tensor<i32>) -> (tensor<i
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
// CHECK-NEXT: %{{[0-9]*}} = "tf.opE"(%[[ARG_0]])
// CHECK-NEXT: tf_executor.yield %[[OP_C]] : tensor<i32>
// CHECK: %{{[0-9]*}}:2 = tf_executor.island {
// CHECK: tf_executor.island {
// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_1]])
// CHECK-NEXT: tf_executor.yield %[[OP_F]] : tensor<i32>
// CHECK: tf_executor.fetch %[[ISLAND_0]]#0, %[[ISLAND_1]]#0 : tensor<i32>, tensor<i32>
@ -302,7 +302,6 @@ func @merge_islands_only() {
// CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32>
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3]]#1, %[[EXIT]]#1
// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ISLAND_3]]#0, %[[CT]]
// CHECK-NEXT: tf_executor.fetch
// Test no merging took place as cycle would be formed otherwise.
@ -327,10 +326,9 @@ func @simple_potential_cycle() {
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"
// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32>
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND]]#1
// CHECK-NEXT: %{{[0-9]*}}:3 = tf_executor.island(%[[CT]]) {
// CHECK-NEXT: tf_executor.island(%[[CT]]) {
// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"
// CHECK-NEXT: tf_executor.yield %[[ISLAND]]#0, %[[OP_B]] : tensor<1xf32>, tensor<1xf32>
// CHECK: tf_executor.fetch
// Test if island was merged into its result.
@ -352,8 +350,87 @@ func @merge_into_result() {
}
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger
// CHECK-NEXT: %{{[0-9]*}} = tf_executor.island(%[[CT]]) {
// CHECK-NEXT: tf_executor.island(%[[CT]]) {
// CHECK-NEXT: "tf.opA"
// CHECK-NEXT: "tf.opB"
// CHECK-NEXT: tf_executor.yield
// CHECK: tf_executor.fetch
// Test merging island into data result nested in a graph of another island.
// CHECK-LABEL: func @merge_into_nested_data_result
// Input for the nested data result merge test. The outer graph holds two
// islands: the first wraps "tf.opA"; the second wraps an inner graph whose
// island consumes the first island's data output (%0#0).
func @merge_into_nested_data_result() {
tf_executor.graph {
// First outer island: produces the value consumed inside the nested graph.
%0:2 = tf_executor.island {
%1 = "tf.opA"() : () -> tensor<1xf32>
tf_executor.yield %1 : tensor<1xf32>
}
// Second outer island: its only data use of %0#0 is nested two levels down
// (inner graph -> inner island -> "tf.opB").
%2:2 = tf_executor.island {
%3 = tf_executor.graph {
%4 = tf_executor.ControlTrigger {}
// Inner island takes the ControlTrigger as its control input.
%5:2 = tf_executor.island(%4) {
%6 = "tf.opB"(%0#0) : (tensor<1xf32>) -> tensor<1xf32>
tf_executor.yield %6 : tensor<1xf32>
}
tf_executor.fetch %5#0 : tensor<1xf32>
}
tf_executor.yield %3 : tensor<1xf32>
}
// The outer graph fetches nothing.
tf_executor.fetch
}
return
}
// CHECK: tf_executor.island {
// CHECK-NEXT: [[OP_A:[0-9]*]] = "tf.opA"
// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph {
// CHECK-NEXT: [[CT:[0-9]*]] = tf_executor.ControlTrigger
// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island(%[[CT]]) {
// CHECK-NEXT: [[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
// CHECK-NEXT: tf_executor.yield %[[OP_B]] : tensor<1xf32>
// CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32>
// CHECK: tf_executor.yield
// Test merging islands in a nested graph.
// CHECK-LABEL: func @merge_islands_inner_graph
// Input for the nested graph merge test. The outer graph holds an island
// wrapping "tf.opA" and a second island wrapping an inner graph; the inner
// graph itself contains three islands connected by control and data edges.
func @merge_islands_inner_graph() {
tf_executor.graph {
// First outer island: independent of the second one (its outputs are unused).
%0:2 = tf_executor.island {
%1 = "tf.opA"() : () -> tensor<1xf32>
tf_executor.yield %1 : tensor<1xf32>
}
%2:2 = tf_executor.island {
%3 = tf_executor.graph {
// Inner island wrapping "tf.opB"; only its control output (%4#1) is used.
%4:2 = tf_executor.island {
%5 = "tf.opB"() : () -> tensor<1xf32>
tf_executor.yield %5 : tensor<1xf32>
}
// Inner island wrapping "tf.opC"; its data output (%6#0) feeds "tf.opD".
%6:2 = tf_executor.island {
%7 = "tf.opC"() : () -> tensor<1xf32>
tf_executor.yield %7 : tensor<1xf32>
}
// Inner island with a control dependency on the opB island and a data
// dependency on the opC island.
%8:2 = tf_executor.island(%4#1) {
%9 = "tf.opD"(%6#0) : (tensor<1xf32>) -> tensor<1xf32>
tf_executor.yield %9 : tensor<1xf32>
}
tf_executor.fetch %8#0 : tensor<1xf32>
}
tf_executor.yield %3 : tensor<1xf32>
}
// The outer graph fetches nothing.
tf_executor.fetch
}
return
}
// CHECK: tf_executor.island {
// CHECK-NEXT: [[OP_A:[0-9]*]] = "tf.opA"
// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32>
// CHECK: tf_executor.island {
// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph {
// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island {
// CHECK-NEXT: "tf.opB"
// CHECK-NEXT: [[OP_C:[0-9]*]] = "tf.opC"
// CHECK-NEXT: [[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]])
// CHECK-NEXT: tf_executor.yield %[[OP_D]] : tensor<1xf32>
// CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32>
// CHECK: tf_executor.yield %[[INNER_GRAPH]] : tensor<1xf32>

View File

@ -32,9 +32,9 @@ limitations under the License.
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/core/platform/logging.h"
namespace mlir {
namespace TFExecutor {
@ -72,26 +72,24 @@ struct ExecutorIslandCoarsening
// the op found is not an island, an empty optional is returned.
llvm::Optional<tf_executor::IslandOp> GetOperandCandidateToMergeWith(
tf_executor::IslandOp* island) {
Operation* graph_op = island->getParentOp();
Operation* candidate = nullptr;
// Check island control operands.
for (Value* island_operand : island->controlInputs()) {
Operation* island_operand_op = island_operand->getDefiningOp();
if (!candidate || candidate->isBeforeInBlock(island_operand_op))
candidate = island_operand_op;
for (Value* input : island->controlInputs()) {
Operation* def = input->getDefiningOp();
DCHECK_EQ(def->getParentOp(), graph_op);
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
}
// Check island data operands.
llvm::SetVector<Value*> inputs;
mlir::getUsedValuesDefinedAbove(island->body(), island->body(), inputs);
for (Value* input : inputs) {
Operation* input_op_def = input->getDefiningOp();
// Input may be a function arg.
if (!input_op_def) continue;
if (!candidate || candidate->isBeforeInBlock(input_op_def))
candidate = input_op_def;
}
island->walk([graph_op, &candidate](Operation* op) {
for (Value* input : op->getOperands()) {
Operation* def = input->getDefiningOp();
if (!def || def->getParentOp() != graph_op) continue;
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
}
});
if (!candidate || !llvm::isa<tf_executor::IslandOp>(candidate))
return llvm::None;
@ -109,16 +107,19 @@ llvm::Optional<tf_executor::IslandOp> GetResultCandidateToMergeWith(
Operation* graph_op = island->getParentOp();
Operation* candidate = nullptr;
// Check island control and data results.
for (Value* result : island->getResults()) {
for (Operation* user : result->getUsers()) {
Operation* user_op = user->getParentOp() == graph_op
? user
: user->getParentOfType<tf_executor::IslandOp>();
if (!user_op) continue;
// Check island control results.
for (Operation* user : island->control()->getUsers()) {
DCHECK_EQ(user->getParentOp(), graph_op);
if (!candidate || candidate->isBeforeInBlock(user)) candidate = user;
}
if (!candidate || user_op->isBeforeInBlock(candidate))
candidate = user_op;
// Check island data results.
Block& graph_block = llvm::cast<tf_executor::GraphOp>(graph_op).GetBody();
for (Value* result : island->outputs()) {
for (Operation* user : result->getUsers()) {
Operation* def = graph_block.findAncestorInstInBlock(*user);
DCHECK_NE(def, nullptr);
if (!candidate || def->isBeforeInBlock(candidate)) candidate = def;
}
}
@ -142,14 +143,16 @@ llvm::SmallSetVector<Value*, 8> GetNewIslandOperands(
// Collects the results for the new island by going through each data output of
// the islands being merged. Unused results outside of the merged island to be
// formed are pruned. Results of the parent island that are consumed by the
// child island are replaced by the respecitve inner ops output from the parent
// formed are pruned. If the child island inner ops consume the parent island
// control output, the child island inner ops will have that respective control
// input pruned. Results of the parent island that are consumed by the child
// island are replaced by the respective inner ops output from the parent
// island.
llvm::SmallVector<Output, 8> GetNewIslandResultsAndForwardOutputs(
mlir::MLIRContext* context, tf_executor::IslandOp* parent,
tf_executor::IslandOp* child, llvm::SmallVector<Type, 8>* result_types) {
llvm::SmallVector<Output, 8> results;
Operation* child_op = child->getOperation();
Block& child_body = child->GetBody();
int result_index = 0;
Operation& last_op = parent->GetBody().back();
@ -158,8 +161,7 @@ llvm::SmallVector<Output, 8> GetNewIslandResultsAndForwardOutputs(
bool output_captured = false;
Value* yield_input = yield_op.getOperand(result_index);
for (auto& use : llvm::make_early_inc_range(output->getUses())) {
if (use.getOwner()->getParentOfType<tf_executor::IslandOp>() ==
child_op) {
if (child_body.findAncestorInstInBlock(*use.getOwner())) {
// Forward output from inner op.
use.set(yield_input);
} else if (!output_captured) {
@ -333,6 +335,7 @@ void ExecutorIslandCoarsening::runOnFunction() {
}
} while (updated);
});
// TODO(lyandy): Add canonicalization for deduplicating control inputs.
}
} // namespace