Cleanup TensorFlow executor dialect island coarsening pass to match TPU cluster formation pass style more (NFC).
Updated functions to not return an output and have output parameters. Replaced `output` with `result`. Switched from storing island type and result index to results directly. PiperOrigin-RevId: 266509565
This commit is contained in:
parent
5025675366
commit
98bd5a6e22
@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This transformation pass takes TFExecutor dialect IslandOps and merges them.
|
||||
// Note, this currently does not handle TensorFlow V1 style control flow/frames
|
||||
// or side effecting ops yet.
|
||||
// This transformation pass takes TensorFlow executor dialect IslandOps and
|
||||
// merges them. Note, this currently does not handle TensorFlow V1 style control
|
||||
// flow/frames or side effecting ops yet.
|
||||
|
||||
#include <iterator>
|
||||
#include <tuple>
|
||||
@ -46,25 +46,19 @@ namespace {
|
||||
// merging another island or is the island (child) being being merged.
|
||||
enum IslandType { kParentIsland, kChildIsland };
|
||||
|
||||
// Output is a helper struct holding a result index and island type (parent or
|
||||
// child).
|
||||
struct Output {
|
||||
Output(IslandType island_type, int result_index)
|
||||
: island_type(island_type), result_index(result_index) {}
|
||||
// IslandResult is a helper struct holding an islands result and associated
|
||||
// inner op result.
|
||||
struct IslandResult {
|
||||
IslandResult(Value* inner_op_result, Value* island_result)
|
||||
: inner_op_result(inner_op_result), island_result(island_result) {}
|
||||
|
||||
IslandType island_type;
|
||||
int result_index;
|
||||
Value* inner_op_result;
|
||||
Value* island_result;
|
||||
};
|
||||
|
||||
struct ExecutorIslandCoarsening
|
||||
: public FunctionPass<ExecutorIslandCoarsening> {
|
||||
void runOnFunction() override;
|
||||
|
||||
private:
|
||||
void MergeIslands(IslandOp parent, IslandOp child,
|
||||
IslandType insert_position);
|
||||
bool MergeIslandWithOperand(IslandOp child);
|
||||
bool MergeIslandWithResult(IslandOp parent);
|
||||
};
|
||||
|
||||
// Finds the operation leading to an island that the island can be merged with.
|
||||
@ -97,7 +91,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
|
||||
}
|
||||
|
||||
// Finds the operation leading from an island that the island can be merged
|
||||
// with. This looks for the operation, either control output or data output to
|
||||
// with. This looks for the operation, either control result or data result to
|
||||
// an op, that is closest to the island in the graph. If no candidate can be
|
||||
// found or the op found is not an island, an empty optional is returned.
|
||||
llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
|
||||
@ -136,54 +130,60 @@ llvm::SmallSetVector<Value*, 8> GetNewIslandOperands(IslandOp parent,
|
||||
return operands;
|
||||
}
|
||||
|
||||
// Collects the results for the new island by going through each data output of
|
||||
// Collects the results for the new island by going through each data result of
|
||||
// the islands being merged. Unused results outside of the merged island to be
|
||||
// formed are pruned. If the child island inner ops consume the parent island
|
||||
// control output, the child island inner ops will have that respective control
|
||||
// control result, the child island inner ops will have that respective control
|
||||
// input pruned. Results of the parent island that are consumed by the child
|
||||
// island are replaced by the respective inner ops output from the parent
|
||||
// island are replaced by the respective inner ops result from the parent
|
||||
// island.
|
||||
llvm::SmallVector<Output, 8> GetNewIslandResultsAndForwardOutputs(
|
||||
mlir::MLIRContext* context, IslandOp parent, IslandOp child,
|
||||
llvm::SmallVector<Type, 8>* result_types) {
|
||||
llvm::SmallVector<Output, 8> results;
|
||||
llvm::SmallVector<IslandResult, 8> GetNewIslandResultsAndForwardResults(
|
||||
IslandOp parent, IslandOp child) {
|
||||
llvm::SmallVector<IslandResult, 8> results;
|
||||
|
||||
YieldOp yield_op = parent.GetYield();
|
||||
Block& child_body = child.GetBody();
|
||||
for (auto& ret_and_idx : llvm::enumerate(parent.outputs())) {
|
||||
bool output_captured = false;
|
||||
Value* yield_input = yield_op.getOperand(ret_and_idx.index());
|
||||
for (auto& use :
|
||||
llvm::make_early_inc_range(ret_and_idx.value()->getUses())) {
|
||||
for (auto ret_vals :
|
||||
llvm::zip(parent.GetYield().getOperands(), parent.outputs())) {
|
||||
bool result_captured = false;
|
||||
Value* inner_op_result = std::get<0>(ret_vals);
|
||||
Value* island_result = std::get<1>(ret_vals);
|
||||
for (auto& use : llvm::make_early_inc_range(island_result->getUses())) {
|
||||
if (child_body.findAncestorInstInBlock(*use.getOwner())) {
|
||||
// Forward output from inner op.
|
||||
use.set(yield_input);
|
||||
} else if (!output_captured) {
|
||||
results.push_back(
|
||||
Output(IslandType::kParentIsland, ret_and_idx.index()));
|
||||
result_types->push_back(ret_and_idx.value()->getType());
|
||||
output_captured = true;
|
||||
// Forward result from inner op.
|
||||
use.set(inner_op_result);
|
||||
} else if (!result_captured) {
|
||||
results.emplace_back(inner_op_result, island_result);
|
||||
result_captured = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& ret_and_idx : llvm::enumerate(child.outputs())) {
|
||||
if (!ret_and_idx.value()->use_empty()) {
|
||||
results.push_back(Output(IslandType::kChildIsland, ret_and_idx.index()));
|
||||
result_types->push_back(ret_and_idx.value()->getType());
|
||||
for (auto ret_vals :
|
||||
llvm::zip(child.GetYield().getOperands(), child.outputs())) {
|
||||
Value* inner_op_result = std::get<0>(ret_vals);
|
||||
Value* island_result = std::get<1>(ret_vals);
|
||||
if (!island_result->use_empty()) {
|
||||
results.emplace_back(inner_op_result, island_result);
|
||||
}
|
||||
}
|
||||
|
||||
// IslandOps always have a control output.
|
||||
result_types->push_back(ControlType::get(context));
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
// Creates the new merged island.
|
||||
IslandOp CreateNewIsland(Operation* old_island,
|
||||
llvm::ArrayRef<Type> result_types,
|
||||
llvm::ArrayRef<Value*> operands) {
|
||||
IslandOp CreateNewIsland(IslandOp parent, IslandOp child,
|
||||
IslandType insert_position,
|
||||
llvm::ArrayRef<Value*> operands,
|
||||
llvm::ArrayRef<IslandResult> results) {
|
||||
// Collect types from results.
|
||||
llvm::SmallVector<Type, 8> result_types;
|
||||
for (const auto& result : results)
|
||||
result_types.push_back(result.inner_op_result->getType());
|
||||
|
||||
// IslandOps always have a control result.
|
||||
result_types.push_back(ControlType::get(parent.getContext()));
|
||||
|
||||
Operation* old_island = insert_position == kParentIsland ? parent : child;
|
||||
OpBuilder builder(old_island);
|
||||
auto new_island = builder.create<IslandOp>(
|
||||
old_island->getLoc(), result_types, operands, ArrayRef<NamedAttribute>{});
|
||||
@ -193,22 +193,18 @@ IslandOp CreateNewIsland(Operation* old_island,
|
||||
|
||||
// Creates respective YieldOp for the new merged island.
|
||||
YieldOp CreateNewIslandYieldOp(IslandOp new_island,
|
||||
llvm::ArrayRef<Output> results, IslandOp parent,
|
||||
IslandOp child) {
|
||||
llvm::ArrayRef<IslandResult> results) {
|
||||
llvm::SmallVector<Value*, 8> yield_operands;
|
||||
yield_operands.reserve(results.size());
|
||||
|
||||
for (auto ret_vals : llvm::zip(results, new_island.outputs())) {
|
||||
// Get consumed output (island type and result index).
|
||||
const auto& output = std::get<0>(ret_vals);
|
||||
IslandOp& output_island =
|
||||
output.island_type == IslandType::kParentIsland ? parent : child;
|
||||
Value* result = output_island.getResult(output.result_index);
|
||||
// Replace original result with new island result.
|
||||
result->replaceAllUsesWith(std::get<1>(ret_vals));
|
||||
// Find YieldOp in original island, grab the associated operand (inner op
|
||||
// output) and add it as a operand to the YieldOp of the merged island.
|
||||
yield_operands.push_back(
|
||||
output_island.GetYield().getOperand(output.result_index));
|
||||
const auto& old_result = std::get<0>(ret_vals);
|
||||
|
||||
// Replace original island result with new island result.
|
||||
old_result.island_result->replaceAllUsesWith(std::get<1>(ret_vals));
|
||||
|
||||
// Add associated inner op result to operands of the YieldOp.
|
||||
yield_operands.push_back(old_result.inner_op_result);
|
||||
}
|
||||
|
||||
// Create YieldOp for the new island.
|
||||
@ -234,25 +230,21 @@ void MoveInnerOpsToNewIsland(IslandOp parent, IslandOp child,
|
||||
}
|
||||
|
||||
// Merges two islands and places new merged island before parent or child.
|
||||
void ExecutorIslandCoarsening::MergeIslands(IslandOp parent, IslandOp child,
|
||||
IslandType insert_position) {
|
||||
void MergeIslands(IslandOp parent, IslandOp child, IslandType insert_position) {
|
||||
// Collect operands for the new merged island.
|
||||
llvm::SmallSetVector<Value*, 8> operands =
|
||||
GetNewIslandOperands(parent, child);
|
||||
|
||||
// Collect results and result types for the new merged island.
|
||||
llvm::SmallVector<Type, 8> result_types;
|
||||
llvm::SmallVector<Output, 8> results = GetNewIslandResultsAndForwardOutputs(
|
||||
&getContext(), parent, child, &result_types);
|
||||
// Collect results for the new merged island.
|
||||
llvm::SmallVector<IslandResult, 8> results =
|
||||
GetNewIslandResultsAndForwardResults(parent, child);
|
||||
|
||||
// Create the new merged island.
|
||||
IslandOp new_island = CreateNewIsland(
|
||||
insert_position == IslandType::kParentIsland ? parent : child,
|
||||
result_types, operands.getArrayRef());
|
||||
IslandOp new_island = CreateNewIsland(parent, child, insert_position,
|
||||
operands.getArrayRef(), results);
|
||||
|
||||
// Create associated YieldOp for the new merged island.
|
||||
YieldOp new_yield_op =
|
||||
CreateNewIslandYieldOp(new_island, results, parent, child);
|
||||
YieldOp new_yield_op = CreateNewIslandYieldOp(new_island, results);
|
||||
|
||||
// Move inner ops from original islands into the new island.
|
||||
MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation());
|
||||
@ -270,12 +262,11 @@ void ExecutorIslandCoarsening::MergeIslands(IslandOp parent, IslandOp child,
|
||||
// operand must be another IslandOp for merging to take place. A new island is
|
||||
// created and the islands being merged are removed if a merge took place.
|
||||
// Returns true if the island was merged with its operand.
|
||||
bool ExecutorIslandCoarsening::MergeIslandWithOperand(IslandOp child) {
|
||||
bool MergeIslandWithOperand(IslandOp child) {
|
||||
// Find candidate operand to merge island with.
|
||||
llvm::Optional<IslandOp> candidate = GetOperandCandidateToMergeWith(child);
|
||||
if (!candidate.hasValue()) return false;
|
||||
auto& parent = candidate.getValue();
|
||||
MergeIslands(parent, child, IslandType::kParentIsland);
|
||||
MergeIslands(candidate.getValue(), child, kParentIsland);
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -283,17 +274,16 @@ bool ExecutorIslandCoarsening::MergeIslandWithOperand(IslandOp child) {
|
||||
// must be another IslandOp for merging to take place. A new island is created
|
||||
// and the islands being merged are removed if a merge took place. Returns true
|
||||
// if the island was merged with its result.
|
||||
bool ExecutorIslandCoarsening::MergeIslandWithResult(IslandOp parent) {
|
||||
bool MergeIslandWithResult(IslandOp parent) {
|
||||
// Find candidate result to merge island with.
|
||||
llvm::Optional<IslandOp> candidate = GetResultCandidateToMergeWith(parent);
|
||||
if (!candidate.hasValue()) return false;
|
||||
auto& child = candidate.getValue();
|
||||
MergeIslands(parent, child, IslandType::kChildIsland);
|
||||
return false;
|
||||
MergeIslands(parent, candidate.getValue(), kChildIsland);
|
||||
return true;
|
||||
}
|
||||
|
||||
void ExecutorIslandCoarsening::runOnFunction() {
|
||||
getFunction().walk([this](GraphOp graph) {
|
||||
getFunction().walk([](GraphOp graph) {
|
||||
Block& graph_body = graph.GetBody();
|
||||
|
||||
bool updated = false;
|
||||
@ -323,7 +313,8 @@ std::unique_ptr<FunctionPassBase> CreateTFExecutorIslandCoarseningPass() {
|
||||
}
|
||||
|
||||
static PassRegistration<ExecutorIslandCoarsening> pass(
|
||||
"tf-executor-island-coarsening", "Merges TFExecutor dialect IslandOps");
|
||||
"tf-executor-island-coarsening",
|
||||
"Merges TensorFlow executor dialect IslandOps");
|
||||
|
||||
} // namespace tf_executor
|
||||
} // namespace mlir
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user