Cleanup TensorFlow executor dialect island coarsening pass to match TPU cluster formation pass style more (NFC).
Updated functions to not return an output and have output parameters. Replaced `output` with `result`. Switched from storing island type and result index to results directly. PiperOrigin-RevId: 266509565
This commit is contained in:
parent
5025675366
commit
98bd5a6e22
@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
// This transformation pass takes TFExecutor dialect IslandOps and merges them.
|
// This transformation pass takes TensorFlow executor dialect IslandOps and
|
||||||
// Note, this currently does not handle TensorFlow V1 style control flow/frames
|
// merges them. Note, this currently does not handle TensorFlow V1 style control
|
||||||
// or side effecting ops yet.
|
// flow/frames or side effecting ops yet.
|
||||||
|
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
@ -46,25 +46,19 @@ namespace {
|
|||||||
// merging another island or is the island (child) being being merged.
|
// merging another island or is the island (child) being being merged.
|
||||||
enum IslandType { kParentIsland, kChildIsland };
|
enum IslandType { kParentIsland, kChildIsland };
|
||||||
|
|
||||||
// Output is a helper struct holding a result index and island type (parent or
|
// IslandResult is a helper struct holding an islands result and associated
|
||||||
// child).
|
// inner op result.
|
||||||
struct Output {
|
struct IslandResult {
|
||||||
Output(IslandType island_type, int result_index)
|
IslandResult(Value* inner_op_result, Value* island_result)
|
||||||
: island_type(island_type), result_index(result_index) {}
|
: inner_op_result(inner_op_result), island_result(island_result) {}
|
||||||
|
|
||||||
IslandType island_type;
|
Value* inner_op_result;
|
||||||
int result_index;
|
Value* island_result;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ExecutorIslandCoarsening
|
struct ExecutorIslandCoarsening
|
||||||
: public FunctionPass<ExecutorIslandCoarsening> {
|
: public FunctionPass<ExecutorIslandCoarsening> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
|
|
||||||
private:
|
|
||||||
void MergeIslands(IslandOp parent, IslandOp child,
|
|
||||||
IslandType insert_position);
|
|
||||||
bool MergeIslandWithOperand(IslandOp child);
|
|
||||||
bool MergeIslandWithResult(IslandOp parent);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Finds the operation leading to an island that the island can be merged with.
|
// Finds the operation leading to an island that the island can be merged with.
|
||||||
@ -97,7 +91,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Finds the operation leading from an island that the island can be merged
|
// Finds the operation leading from an island that the island can be merged
|
||||||
// with. This looks for the operation, either control output or data output to
|
// with. This looks for the operation, either control result or data result to
|
||||||
// an op, that is closest to the island in the graph. If no candidate can be
|
// an op, that is closest to the island in the graph. If no candidate can be
|
||||||
// found or the op found is not an island, an empty optional is returned.
|
// found or the op found is not an island, an empty optional is returned.
|
||||||
llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
|
llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
|
||||||
@ -136,54 +130,60 @@ llvm::SmallSetVector<Value*, 8> GetNewIslandOperands(IslandOp parent,
|
|||||||
return operands;
|
return operands;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collects the results for the new island by going through each data output of
|
// Collects the results for the new island by going through each data result of
|
||||||
// the islands being merged. Unused results outside of the merged island to be
|
// the islands being merged. Unused results outside of the merged island to be
|
||||||
// formed are pruned. If the child island inner ops consume the parent island
|
// formed are pruned. If the child island inner ops consume the parent island
|
||||||
// control output, the child island inner ops will have that respective control
|
// control result, the child island inner ops will have that respective control
|
||||||
// input pruned. Results of the parent island that are consumed by the child
|
// input pruned. Results of the parent island that are consumed by the child
|
||||||
// island are replaced by the respective inner ops output from the parent
|
// island are replaced by the respective inner ops result from the parent
|
||||||
// island.
|
// island.
|
||||||
llvm::SmallVector<Output, 8> GetNewIslandResultsAndForwardOutputs(
|
llvm::SmallVector<IslandResult, 8> GetNewIslandResultsAndForwardResults(
|
||||||
mlir::MLIRContext* context, IslandOp parent, IslandOp child,
|
IslandOp parent, IslandOp child) {
|
||||||
llvm::SmallVector<Type, 8>* result_types) {
|
llvm::SmallVector<IslandResult, 8> results;
|
||||||
llvm::SmallVector<Output, 8> results;
|
|
||||||
|
|
||||||
YieldOp yield_op = parent.GetYield();
|
|
||||||
Block& child_body = child.GetBody();
|
Block& child_body = child.GetBody();
|
||||||
for (auto& ret_and_idx : llvm::enumerate(parent.outputs())) {
|
for (auto ret_vals :
|
||||||
bool output_captured = false;
|
llvm::zip(parent.GetYield().getOperands(), parent.outputs())) {
|
||||||
Value* yield_input = yield_op.getOperand(ret_and_idx.index());
|
bool result_captured = false;
|
||||||
for (auto& use :
|
Value* inner_op_result = std::get<0>(ret_vals);
|
||||||
llvm::make_early_inc_range(ret_and_idx.value()->getUses())) {
|
Value* island_result = std::get<1>(ret_vals);
|
||||||
|
for (auto& use : llvm::make_early_inc_range(island_result->getUses())) {
|
||||||
if (child_body.findAncestorInstInBlock(*use.getOwner())) {
|
if (child_body.findAncestorInstInBlock(*use.getOwner())) {
|
||||||
// Forward output from inner op.
|
// Forward result from inner op.
|
||||||
use.set(yield_input);
|
use.set(inner_op_result);
|
||||||
} else if (!output_captured) {
|
} else if (!result_captured) {
|
||||||
results.push_back(
|
results.emplace_back(inner_op_result, island_result);
|
||||||
Output(IslandType::kParentIsland, ret_and_idx.index()));
|
result_captured = true;
|
||||||
result_types->push_back(ret_and_idx.value()->getType());
|
|
||||||
output_captured = true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& ret_and_idx : llvm::enumerate(child.outputs())) {
|
for (auto ret_vals :
|
||||||
if (!ret_and_idx.value()->use_empty()) {
|
llvm::zip(child.GetYield().getOperands(), child.outputs())) {
|
||||||
results.push_back(Output(IslandType::kChildIsland, ret_and_idx.index()));
|
Value* inner_op_result = std::get<0>(ret_vals);
|
||||||
result_types->push_back(ret_and_idx.value()->getType());
|
Value* island_result = std::get<1>(ret_vals);
|
||||||
|
if (!island_result->use_empty()) {
|
||||||
|
results.emplace_back(inner_op_result, island_result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IslandOps always have a control output.
|
|
||||||
result_types->push_back(ControlType::get(context));
|
|
||||||
|
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates the new merged island.
|
// Creates the new merged island.
|
||||||
IslandOp CreateNewIsland(Operation* old_island,
|
IslandOp CreateNewIsland(IslandOp parent, IslandOp child,
|
||||||
llvm::ArrayRef<Type> result_types,
|
IslandType insert_position,
|
||||||
llvm::ArrayRef<Value*> operands) {
|
llvm::ArrayRef<Value*> operands,
|
||||||
|
llvm::ArrayRef<IslandResult> results) {
|
||||||
|
// Collect types from results.
|
||||||
|
llvm::SmallVector<Type, 8> result_types;
|
||||||
|
for (const auto& result : results)
|
||||||
|
result_types.push_back(result.inner_op_result->getType());
|
||||||
|
|
||||||
|
// IslandOps always have a control result.
|
||||||
|
result_types.push_back(ControlType::get(parent.getContext()));
|
||||||
|
|
||||||
|
Operation* old_island = insert_position == kParentIsland ? parent : child;
|
||||||
OpBuilder builder(old_island);
|
OpBuilder builder(old_island);
|
||||||
auto new_island = builder.create<IslandOp>(
|
auto new_island = builder.create<IslandOp>(
|
||||||
old_island->getLoc(), result_types, operands, ArrayRef<NamedAttribute>{});
|
old_island->getLoc(), result_types, operands, ArrayRef<NamedAttribute>{});
|
||||||
@ -193,22 +193,18 @@ IslandOp CreateNewIsland(Operation* old_island,
|
|||||||
|
|
||||||
// Creates respective YieldOp for the new merged island.
|
// Creates respective YieldOp for the new merged island.
|
||||||
YieldOp CreateNewIslandYieldOp(IslandOp new_island,
|
YieldOp CreateNewIslandYieldOp(IslandOp new_island,
|
||||||
llvm::ArrayRef<Output> results, IslandOp parent,
|
llvm::ArrayRef<IslandResult> results) {
|
||||||
IslandOp child) {
|
|
||||||
llvm::SmallVector<Value*, 8> yield_operands;
|
llvm::SmallVector<Value*, 8> yield_operands;
|
||||||
yield_operands.reserve(results.size());
|
yield_operands.reserve(results.size());
|
||||||
|
|
||||||
for (auto ret_vals : llvm::zip(results, new_island.outputs())) {
|
for (auto ret_vals : llvm::zip(results, new_island.outputs())) {
|
||||||
// Get consumed output (island type and result index).
|
const auto& old_result = std::get<0>(ret_vals);
|
||||||
const auto& output = std::get<0>(ret_vals);
|
|
||||||
IslandOp& output_island =
|
// Replace original island result with new island result.
|
||||||
output.island_type == IslandType::kParentIsland ? parent : child;
|
old_result.island_result->replaceAllUsesWith(std::get<1>(ret_vals));
|
||||||
Value* result = output_island.getResult(output.result_index);
|
|
||||||
// Replace original result with new island result.
|
// Add associated inner op result to operands of the YieldOp.
|
||||||
result->replaceAllUsesWith(std::get<1>(ret_vals));
|
yield_operands.push_back(old_result.inner_op_result);
|
||||||
// Find YieldOp in original island, grab the associated operand (inner op
|
|
||||||
// output) and add it as a operand to the YieldOp of the merged island.
|
|
||||||
yield_operands.push_back(
|
|
||||||
output_island.GetYield().getOperand(output.result_index));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create YieldOp for the new island.
|
// Create YieldOp for the new island.
|
||||||
@ -234,25 +230,21 @@ void MoveInnerOpsToNewIsland(IslandOp parent, IslandOp child,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Merges two islands and places new merged island before parent or child.
|
// Merges two islands and places new merged island before parent or child.
|
||||||
void ExecutorIslandCoarsening::MergeIslands(IslandOp parent, IslandOp child,
|
void MergeIslands(IslandOp parent, IslandOp child, IslandType insert_position) {
|
||||||
IslandType insert_position) {
|
|
||||||
// Collect operands for the new merged island.
|
// Collect operands for the new merged island.
|
||||||
llvm::SmallSetVector<Value*, 8> operands =
|
llvm::SmallSetVector<Value*, 8> operands =
|
||||||
GetNewIslandOperands(parent, child);
|
GetNewIslandOperands(parent, child);
|
||||||
|
|
||||||
// Collect results and result types for the new merged island.
|
// Collect results for the new merged island.
|
||||||
llvm::SmallVector<Type, 8> result_types;
|
llvm::SmallVector<IslandResult, 8> results =
|
||||||
llvm::SmallVector<Output, 8> results = GetNewIslandResultsAndForwardOutputs(
|
GetNewIslandResultsAndForwardResults(parent, child);
|
||||||
&getContext(), parent, child, &result_types);
|
|
||||||
|
|
||||||
// Create the new merged island.
|
// Create the new merged island.
|
||||||
IslandOp new_island = CreateNewIsland(
|
IslandOp new_island = CreateNewIsland(parent, child, insert_position,
|
||||||
insert_position == IslandType::kParentIsland ? parent : child,
|
operands.getArrayRef(), results);
|
||||||
result_types, operands.getArrayRef());
|
|
||||||
|
|
||||||
// Create associated YieldOp for the new merged island.
|
// Create associated YieldOp for the new merged island.
|
||||||
YieldOp new_yield_op =
|
YieldOp new_yield_op = CreateNewIslandYieldOp(new_island, results);
|
||||||
CreateNewIslandYieldOp(new_island, results, parent, child);
|
|
||||||
|
|
||||||
// Move inner ops from original islands into the new island.
|
// Move inner ops from original islands into the new island.
|
||||||
MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation());
|
MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation());
|
||||||
@ -270,12 +262,11 @@ void ExecutorIslandCoarsening::MergeIslands(IslandOp parent, IslandOp child,
|
|||||||
// operand must be another IslandOp for merging to take place. A new island is
|
// operand must be another IslandOp for merging to take place. A new island is
|
||||||
// created and the islands being merged are removed if a merge took place.
|
// created and the islands being merged are removed if a merge took place.
|
||||||
// Returns true if the island was merged with its operand.
|
// Returns true if the island was merged with its operand.
|
||||||
bool ExecutorIslandCoarsening::MergeIslandWithOperand(IslandOp child) {
|
bool MergeIslandWithOperand(IslandOp child) {
|
||||||
// Find candidate operand to merge island with.
|
// Find candidate operand to merge island with.
|
||||||
llvm::Optional<IslandOp> candidate = GetOperandCandidateToMergeWith(child);
|
llvm::Optional<IslandOp> candidate = GetOperandCandidateToMergeWith(child);
|
||||||
if (!candidate.hasValue()) return false;
|
if (!candidate.hasValue()) return false;
|
||||||
auto& parent = candidate.getValue();
|
MergeIslands(candidate.getValue(), child, kParentIsland);
|
||||||
MergeIslands(parent, child, IslandType::kParentIsland);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -283,17 +274,16 @@ bool ExecutorIslandCoarsening::MergeIslandWithOperand(IslandOp child) {
|
|||||||
// must be another IslandOp for merging to take place. A new island is created
|
// must be another IslandOp for merging to take place. A new island is created
|
||||||
// and the islands being merged are removed if a merge took place. Returns true
|
// and the islands being merged are removed if a merge took place. Returns true
|
||||||
// if the island was merged with its result.
|
// if the island was merged with its result.
|
||||||
bool ExecutorIslandCoarsening::MergeIslandWithResult(IslandOp parent) {
|
bool MergeIslandWithResult(IslandOp parent) {
|
||||||
// Find candidate result to merge island with.
|
// Find candidate result to merge island with.
|
||||||
llvm::Optional<IslandOp> candidate = GetResultCandidateToMergeWith(parent);
|
llvm::Optional<IslandOp> candidate = GetResultCandidateToMergeWith(parent);
|
||||||
if (!candidate.hasValue()) return false;
|
if (!candidate.hasValue()) return false;
|
||||||
auto& child = candidate.getValue();
|
MergeIslands(parent, candidate.getValue(), kChildIsland);
|
||||||
MergeIslands(parent, child, IslandType::kChildIsland);
|
return true;
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ExecutorIslandCoarsening::runOnFunction() {
|
void ExecutorIslandCoarsening::runOnFunction() {
|
||||||
getFunction().walk([this](GraphOp graph) {
|
getFunction().walk([](GraphOp graph) {
|
||||||
Block& graph_body = graph.GetBody();
|
Block& graph_body = graph.GetBody();
|
||||||
|
|
||||||
bool updated = false;
|
bool updated = false;
|
||||||
@ -323,7 +313,8 @@ std::unique_ptr<FunctionPassBase> CreateTFExecutorIslandCoarseningPass() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<ExecutorIslandCoarsening> pass(
|
static PassRegistration<ExecutorIslandCoarsening> pass(
|
||||||
"tf-executor-island-coarsening", "Merges TFExecutor dialect IslandOps");
|
"tf-executor-island-coarsening",
|
||||||
|
"Merges TensorFlow executor dialect IslandOps");
|
||||||
|
|
||||||
} // namespace tf_executor
|
} // namespace tf_executor
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user