Add support for handling island merging with nested IslandOp/GraphOp in island coarsening pass.
PiperOrigin-RevId: 260783005
This commit is contained in:
parent
5944bf515a
commit
88f5009af7
@ -228,7 +228,7 @@ func @islands_interleaved(%arg0 : tensor<i32>, %arg1 : tensor<i32>) -> (tensor<i
|
||||
// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]])
|
||||
// CHECK-NEXT: %{{[0-9]*}} = "tf.opE"(%[[ARG_0]])
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_C]] : tensor<i32>
|
||||
// CHECK: %{{[0-9]*}}:2 = tf_executor.island {
|
||||
// CHECK: tf_executor.island {
|
||||
// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_1]])
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_F]] : tensor<i32>
|
||||
// CHECK: tf_executor.fetch %[[ISLAND_0]]#0, %[[ISLAND_1]]#0 : tensor<i32>, tensor<i32>
|
||||
@ -302,7 +302,6 @@ func @merge_islands_only() {
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32>
|
||||
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3]]#1, %[[EXIT]]#1
|
||||
// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ISLAND_3]]#0, %[[CT]]
|
||||
// CHECK-NEXT: tf_executor.fetch
|
||||
|
||||
|
||||
// Test no merging took place as cycle would be formed otherwise.
|
||||
@ -327,10 +326,9 @@ func @simple_potential_cycle() {
|
||||
// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32>
|
||||
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND]]#1
|
||||
// CHECK-NEXT: %{{[0-9]*}}:3 = tf_executor.island(%[[CT]]) {
|
||||
// CHECK-NEXT: tf_executor.island(%[[CT]]) {
|
||||
// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"
|
||||
// CHECK-NEXT: tf_executor.yield %[[ISLAND]]#0, %[[OP_B]] : tensor<1xf32>, tensor<1xf32>
|
||||
// CHECK: tf_executor.fetch
|
||||
|
||||
|
||||
// Test if island was merged into its result.
|
||||
@ -352,8 +350,87 @@ func @merge_into_result() {
|
||||
}
|
||||
|
||||
// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger
|
||||
// CHECK-NEXT: %{{[0-9]*}} = tf_executor.island(%[[CT]]) {
|
||||
// CHECK-NEXT: tf_executor.island(%[[CT]]) {
|
||||
// CHECK-NEXT: "tf.opA"
|
||||
// CHECK-NEXT: "tf.opB"
|
||||
// CHECK-NEXT: tf_executor.yield
|
||||
// CHECK: tf_executor.fetch
|
||||
|
||||
|
||||
// Test merging island into data result nested in a graph of another island.
|
||||
// CHECK-LABEL: func @merge_into_nested_data_result
|
||||
func @merge_into_nested_data_result() {
|
||||
tf_executor.graph {
|
||||
%0:2 = tf_executor.island {
|
||||
%1 = "tf.opA"() : () -> tensor<1xf32>
|
||||
tf_executor.yield %1 : tensor<1xf32>
|
||||
}
|
||||
%2:2 = tf_executor.island {
|
||||
%3 = tf_executor.graph {
|
||||
%4 = tf_executor.ControlTrigger {}
|
||||
%5:2 = tf_executor.island(%4) {
|
||||
%6 = "tf.opB"(%0#0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
tf_executor.yield %6 : tensor<1xf32>
|
||||
}
|
||||
tf_executor.fetch %5#0 : tensor<1xf32>
|
||||
}
|
||||
tf_executor.yield %3 : tensor<1xf32>
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: tf_executor.island {
|
||||
// CHECK-NEXT: [[OP_A:[0-9]*]] = "tf.opA"
|
||||
// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph {
|
||||
// CHECK-NEXT: [[CT:[0-9]*]] = tf_executor.ControlTrigger
|
||||
// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island(%[[CT]]) {
|
||||
// CHECK-NEXT: [[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]])
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_B]] : tensor<1xf32>
|
||||
// CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32>
|
||||
// CHECK: tf_executor.yield
|
||||
|
||||
|
||||
// Test merging islands in a nested graph.
|
||||
// CHECK-LABEL: func @merge_islands_inner_graph
|
||||
func @merge_islands_inner_graph() {
|
||||
tf_executor.graph {
|
||||
%0:2 = tf_executor.island {
|
||||
%1 = "tf.opA"() : () -> tensor<1xf32>
|
||||
tf_executor.yield %1 : tensor<1xf32>
|
||||
}
|
||||
%2:2 = tf_executor.island {
|
||||
%3 = tf_executor.graph {
|
||||
%4:2 = tf_executor.island {
|
||||
%5 = "tf.opB"() : () -> tensor<1xf32>
|
||||
tf_executor.yield %5 : tensor<1xf32>
|
||||
}
|
||||
%6:2 = tf_executor.island {
|
||||
%7 = "tf.opC"() : () -> tensor<1xf32>
|
||||
tf_executor.yield %7 : tensor<1xf32>
|
||||
}
|
||||
%8:2 = tf_executor.island(%4#1) {
|
||||
%9 = "tf.opD"(%6#0) : (tensor<1xf32>) -> tensor<1xf32>
|
||||
tf_executor.yield %9 : tensor<1xf32>
|
||||
}
|
||||
tf_executor.fetch %8#0 : tensor<1xf32>
|
||||
}
|
||||
tf_executor.yield %3 : tensor<1xf32>
|
||||
}
|
||||
tf_executor.fetch
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: tf_executor.island {
|
||||
// CHECK-NEXT: [[OP_A:[0-9]*]] = "tf.opA"
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32>
|
||||
// CHECK: tf_executor.island {
|
||||
// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph {
|
||||
// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island {
|
||||
// CHECK-NEXT: "tf.opB"
|
||||
// CHECK-NEXT: [[OP_C:[0-9]*]] = "tf.opC"
|
||||
// CHECK-NEXT: [[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]])
|
||||
// CHECK-NEXT: tf_executor.yield %[[OP_D]] : tensor<1xf32>
|
||||
// CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32>
|
||||
// CHECK: tf_executor.yield %[[INNER_GRAPH]] : tensor<1xf32>
|
||||
|
@ -32,9 +32,9 @@ limitations under the License.
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFExecutor {
|
||||
@ -72,26 +72,24 @@ struct ExecutorIslandCoarsening
|
||||
// the op found is not an island, an empty optional is returned.
|
||||
llvm::Optional<tf_executor::IslandOp> GetOperandCandidateToMergeWith(
|
||||
tf_executor::IslandOp* island) {
|
||||
Operation* graph_op = island->getParentOp();
|
||||
Operation* candidate = nullptr;
|
||||
|
||||
// Check island control operands.
|
||||
for (Value* island_operand : island->controlInputs()) {
|
||||
Operation* island_operand_op = island_operand->getDefiningOp();
|
||||
if (!candidate || candidate->isBeforeInBlock(island_operand_op))
|
||||
candidate = island_operand_op;
|
||||
for (Value* input : island->controlInputs()) {
|
||||
Operation* def = input->getDefiningOp();
|
||||
DCHECK_EQ(def->getParentOp(), graph_op);
|
||||
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
|
||||
}
|
||||
|
||||
// Check island data operands.
|
||||
llvm::SetVector<Value*> inputs;
|
||||
mlir::getUsedValuesDefinedAbove(island->body(), island->body(), inputs);
|
||||
for (Value* input : inputs) {
|
||||
Operation* input_op_def = input->getDefiningOp();
|
||||
// Input may be a function arg.
|
||||
if (!input_op_def) continue;
|
||||
|
||||
if (!candidate || candidate->isBeforeInBlock(input_op_def))
|
||||
candidate = input_op_def;
|
||||
}
|
||||
island->walk([graph_op, &candidate](Operation* op) {
|
||||
for (Value* input : op->getOperands()) {
|
||||
Operation* def = input->getDefiningOp();
|
||||
if (!def || def->getParentOp() != graph_op) continue;
|
||||
if (!candidate || candidate->isBeforeInBlock(def)) candidate = def;
|
||||
}
|
||||
});
|
||||
|
||||
if (!candidate || !llvm::isa<tf_executor::IslandOp>(candidate))
|
||||
return llvm::None;
|
||||
@ -109,16 +107,19 @@ llvm::Optional<tf_executor::IslandOp> GetResultCandidateToMergeWith(
|
||||
Operation* graph_op = island->getParentOp();
|
||||
Operation* candidate = nullptr;
|
||||
|
||||
// Check island control and data results.
|
||||
for (Value* result : island->getResults()) {
|
||||
for (Operation* user : result->getUsers()) {
|
||||
Operation* user_op = user->getParentOp() == graph_op
|
||||
? user
|
||||
: user->getParentOfType<tf_executor::IslandOp>();
|
||||
if (!user_op) continue;
|
||||
// Check island control results.
|
||||
for (Operation* user : island->control()->getUsers()) {
|
||||
DCHECK_EQ(user->getParentOp(), graph_op);
|
||||
if (!candidate || candidate->isBeforeInBlock(user)) candidate = user;
|
||||
}
|
||||
|
||||
if (!candidate || user_op->isBeforeInBlock(candidate))
|
||||
candidate = user_op;
|
||||
// Check island data results.
|
||||
Block& graph_block = llvm::cast<tf_executor::GraphOp>(graph_op).GetBody();
|
||||
for (Value* result : island->outputs()) {
|
||||
for (Operation* user : result->getUsers()) {
|
||||
Operation* def = graph_block.findAncestorInstInBlock(*user);
|
||||
DCHECK_NE(def, nullptr);
|
||||
if (!candidate || def->isBeforeInBlock(candidate)) candidate = def;
|
||||
}
|
||||
}
|
||||
|
||||
@ -142,14 +143,16 @@ llvm::SmallSetVector<Value*, 8> GetNewIslandOperands(
|
||||
|
||||
// Collects the results for the new island by going through each data output of
|
||||
// the islands being merged. Unused results outside of the merged island to be
|
||||
// formed are pruned. Results of the parent island that are consumed by the
|
||||
// child island are replaced by the respecitve inner ops output from the parent
|
||||
// formed are pruned. If the child island inner ops consume the parent island
|
||||
// control output, the child island inner ops will have that respective control
|
||||
// input pruned. Results of the parent island that are consumed by the child
|
||||
// island are replaced by the respective inner ops output from the parent
|
||||
// island.
|
||||
llvm::SmallVector<Output, 8> GetNewIslandResultsAndForwardOutputs(
|
||||
mlir::MLIRContext* context, tf_executor::IslandOp* parent,
|
||||
tf_executor::IslandOp* child, llvm::SmallVector<Type, 8>* result_types) {
|
||||
llvm::SmallVector<Output, 8> results;
|
||||
Operation* child_op = child->getOperation();
|
||||
Block& child_body = child->GetBody();
|
||||
int result_index = 0;
|
||||
|
||||
Operation& last_op = parent->GetBody().back();
|
||||
@ -158,8 +161,7 @@ llvm::SmallVector<Output, 8> GetNewIslandResultsAndForwardOutputs(
|
||||
bool output_captured = false;
|
||||
Value* yield_input = yield_op.getOperand(result_index);
|
||||
for (auto& use : llvm::make_early_inc_range(output->getUses())) {
|
||||
if (use.getOwner()->getParentOfType<tf_executor::IslandOp>() ==
|
||||
child_op) {
|
||||
if (child_body.findAncestorInstInBlock(*use.getOwner())) {
|
||||
// Forward output from inner op.
|
||||
use.set(yield_input);
|
||||
} else if (!output_captured) {
|
||||
@ -333,6 +335,7 @@ void ExecutorIslandCoarsening::runOnFunction() {
|
||||
}
|
||||
} while (updated);
|
||||
});
|
||||
// TODO(lyandy): Add canonicalization for deduping control inputs.
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user