Cleanup TensorFlow executor dialect island coarsening pass to match TPU cluster formation pass style more (NFC).

Updated functions to not return an output and have output parameters. Replaced `output` with `result`. Switched from storing island type and result index to results directly.

PiperOrigin-RevId: 266509565
This commit is contained in:
Andy Ly 2019-08-30 21:51:06 -07:00 committed by TensorFlower Gardener
parent 5025675366
commit 98bd5a6e22

View File

@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// This transformation pass takes TFExecutor dialect IslandOps and merges them. // This transformation pass takes TensorFlow executor dialect IslandOps and
// Note, this currently does not handle TensorFlow V1 style control flow/frames // merges them. Note, this currently does not handle TensorFlow V1 style control
// or side effecting ops yet. // flow/frames or side effecting ops yet.
#include <iterator> #include <iterator>
#include <tuple> #include <tuple>
@ -46,25 +46,19 @@ namespace {
// merging another island or is the island (child) being being merged. // merging another island or is the island (child) being being merged.
enum IslandType { kParentIsland, kChildIsland }; enum IslandType { kParentIsland, kChildIsland };
// Output is a helper struct holding a result index and island type (parent or // IslandResult is a helper struct holding an islands result and associated
// child). // inner op result.
struct Output { struct IslandResult {
Output(IslandType island_type, int result_index) IslandResult(Value* inner_op_result, Value* island_result)
: island_type(island_type), result_index(result_index) {} : inner_op_result(inner_op_result), island_result(island_result) {}
IslandType island_type; Value* inner_op_result;
int result_index; Value* island_result;
}; };
struct ExecutorIslandCoarsening struct ExecutorIslandCoarsening
: public FunctionPass<ExecutorIslandCoarsening> { : public FunctionPass<ExecutorIslandCoarsening> {
void runOnFunction() override; void runOnFunction() override;
private:
void MergeIslands(IslandOp parent, IslandOp child,
IslandType insert_position);
bool MergeIslandWithOperand(IslandOp child);
bool MergeIslandWithResult(IslandOp parent);
}; };
// Finds the operation leading to an island that the island can be merged with. // Finds the operation leading to an island that the island can be merged with.
@ -97,7 +91,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
} }
// Finds the operation leading from an island that the island can be merged // Finds the operation leading from an island that the island can be merged
// with. This looks for the operation, either control output or data output to // with. This looks for the operation, either control result or data result to
// an op, that is closest to the island in the graph. If no candidate can be // an op, that is closest to the island in the graph. If no candidate can be
// found or the op found is not an island, an empty optional is returned. // found or the op found is not an island, an empty optional is returned.
llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) { llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
@ -136,54 +130,60 @@ llvm::SmallSetVector<Value*, 8> GetNewIslandOperands(IslandOp parent,
return operands; return operands;
} }
// Collects the results for the new island by going through each data output of // Collects the results for the new island by going through each data result of
// the islands being merged. Unused results outside of the merged island to be // the islands being merged. Unused results outside of the merged island to be
// formed are pruned. If the child island inner ops consume the parent island // formed are pruned. If the child island inner ops consume the parent island
// control output, the child island inner ops will have that respective control // control result, the child island inner ops will have that respective control
// input pruned. Results of the parent island that are consumed by the child // input pruned. Results of the parent island that are consumed by the child
// island are replaced by the respective inner ops output from the parent // island are replaced by the respective inner ops result from the parent
// island. // island.
llvm::SmallVector<Output, 8> GetNewIslandResultsAndForwardOutputs( llvm::SmallVector<IslandResult, 8> GetNewIslandResultsAndForwardResults(
mlir::MLIRContext* context, IslandOp parent, IslandOp child, IslandOp parent, IslandOp child) {
llvm::SmallVector<Type, 8>* result_types) { llvm::SmallVector<IslandResult, 8> results;
llvm::SmallVector<Output, 8> results;
YieldOp yield_op = parent.GetYield();
Block& child_body = child.GetBody(); Block& child_body = child.GetBody();
for (auto& ret_and_idx : llvm::enumerate(parent.outputs())) { for (auto ret_vals :
bool output_captured = false; llvm::zip(parent.GetYield().getOperands(), parent.outputs())) {
Value* yield_input = yield_op.getOperand(ret_and_idx.index()); bool result_captured = false;
for (auto& use : Value* inner_op_result = std::get<0>(ret_vals);
llvm::make_early_inc_range(ret_and_idx.value()->getUses())) { Value* island_result = std::get<1>(ret_vals);
for (auto& use : llvm::make_early_inc_range(island_result->getUses())) {
if (child_body.findAncestorInstInBlock(*use.getOwner())) { if (child_body.findAncestorInstInBlock(*use.getOwner())) {
// Forward output from inner op. // Forward result from inner op.
use.set(yield_input); use.set(inner_op_result);
} else if (!output_captured) { } else if (!result_captured) {
results.push_back( results.emplace_back(inner_op_result, island_result);
Output(IslandType::kParentIsland, ret_and_idx.index())); result_captured = true;
result_types->push_back(ret_and_idx.value()->getType());
output_captured = true;
} }
} }
} }
for (auto& ret_and_idx : llvm::enumerate(child.outputs())) { for (auto ret_vals :
if (!ret_and_idx.value()->use_empty()) { llvm::zip(child.GetYield().getOperands(), child.outputs())) {
results.push_back(Output(IslandType::kChildIsland, ret_and_idx.index())); Value* inner_op_result = std::get<0>(ret_vals);
result_types->push_back(ret_and_idx.value()->getType()); Value* island_result = std::get<1>(ret_vals);
if (!island_result->use_empty()) {
results.emplace_back(inner_op_result, island_result);
} }
} }
// IslandOps always have a control output.
result_types->push_back(ControlType::get(context));
return results; return results;
} }
// Creates the new merged island. // Creates the new merged island.
IslandOp CreateNewIsland(Operation* old_island, IslandOp CreateNewIsland(IslandOp parent, IslandOp child,
llvm::ArrayRef<Type> result_types, IslandType insert_position,
llvm::ArrayRef<Value*> operands) { llvm::ArrayRef<Value*> operands,
llvm::ArrayRef<IslandResult> results) {
// Collect types from results.
llvm::SmallVector<Type, 8> result_types;
for (const auto& result : results)
result_types.push_back(result.inner_op_result->getType());
// IslandOps always have a control result.
result_types.push_back(ControlType::get(parent.getContext()));
Operation* old_island = insert_position == kParentIsland ? parent : child;
OpBuilder builder(old_island); OpBuilder builder(old_island);
auto new_island = builder.create<IslandOp>( auto new_island = builder.create<IslandOp>(
old_island->getLoc(), result_types, operands, ArrayRef<NamedAttribute>{}); old_island->getLoc(), result_types, operands, ArrayRef<NamedAttribute>{});
@ -193,22 +193,18 @@ IslandOp CreateNewIsland(Operation* old_island,
// Creates respective YieldOp for the new merged island. // Creates respective YieldOp for the new merged island.
YieldOp CreateNewIslandYieldOp(IslandOp new_island, YieldOp CreateNewIslandYieldOp(IslandOp new_island,
llvm::ArrayRef<Output> results, IslandOp parent, llvm::ArrayRef<IslandResult> results) {
IslandOp child) {
llvm::SmallVector<Value*, 8> yield_operands; llvm::SmallVector<Value*, 8> yield_operands;
yield_operands.reserve(results.size()); yield_operands.reserve(results.size());
for (auto ret_vals : llvm::zip(results, new_island.outputs())) { for (auto ret_vals : llvm::zip(results, new_island.outputs())) {
// Get consumed output (island type and result index). const auto& old_result = std::get<0>(ret_vals);
const auto& output = std::get<0>(ret_vals);
IslandOp& output_island = // Replace original island result with new island result.
output.island_type == IslandType::kParentIsland ? parent : child; old_result.island_result->replaceAllUsesWith(std::get<1>(ret_vals));
Value* result = output_island.getResult(output.result_index);
// Replace original result with new island result. // Add associated inner op result to operands of the YieldOp.
result->replaceAllUsesWith(std::get<1>(ret_vals)); yield_operands.push_back(old_result.inner_op_result);
// Find YieldOp in original island, grab the associated operand (inner op
// output) and add it as a operand to the YieldOp of the merged island.
yield_operands.push_back(
output_island.GetYield().getOperand(output.result_index));
} }
// Create YieldOp for the new island. // Create YieldOp for the new island.
@ -234,25 +230,21 @@ void MoveInnerOpsToNewIsland(IslandOp parent, IslandOp child,
} }
// Merges two islands and places new merged island before parent or child. // Merges two islands and places new merged island before parent or child.
void ExecutorIslandCoarsening::MergeIslands(IslandOp parent, IslandOp child, void MergeIslands(IslandOp parent, IslandOp child, IslandType insert_position) {
IslandType insert_position) {
// Collect operands for the new merged island. // Collect operands for the new merged island.
llvm::SmallSetVector<Value*, 8> operands = llvm::SmallSetVector<Value*, 8> operands =
GetNewIslandOperands(parent, child); GetNewIslandOperands(parent, child);
// Collect results and result types for the new merged island. // Collect results for the new merged island.
llvm::SmallVector<Type, 8> result_types; llvm::SmallVector<IslandResult, 8> results =
llvm::SmallVector<Output, 8> results = GetNewIslandResultsAndForwardOutputs( GetNewIslandResultsAndForwardResults(parent, child);
&getContext(), parent, child, &result_types);
// Create the new merged island. // Create the new merged island.
IslandOp new_island = CreateNewIsland( IslandOp new_island = CreateNewIsland(parent, child, insert_position,
insert_position == IslandType::kParentIsland ? parent : child, operands.getArrayRef(), results);
result_types, operands.getArrayRef());
// Create associated YieldOp for the new merged island. // Create associated YieldOp for the new merged island.
YieldOp new_yield_op = YieldOp new_yield_op = CreateNewIslandYieldOp(new_island, results);
CreateNewIslandYieldOp(new_island, results, parent, child);
// Move inner ops from original islands into the new island. // Move inner ops from original islands into the new island.
MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation()); MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation());
@ -270,12 +262,11 @@ void ExecutorIslandCoarsening::MergeIslands(IslandOp parent, IslandOp child,
// operand must be another IslandOp for merging to take place. A new island is // operand must be another IslandOp for merging to take place. A new island is
// created and the islands being merged are removed if a merge took place. // created and the islands being merged are removed if a merge took place.
// Returns true if the island was merged with its operand. // Returns true if the island was merged with its operand.
bool ExecutorIslandCoarsening::MergeIslandWithOperand(IslandOp child) { bool MergeIslandWithOperand(IslandOp child) {
// Find candidate operand to merge island with. // Find candidate operand to merge island with.
llvm::Optional<IslandOp> candidate = GetOperandCandidateToMergeWith(child); llvm::Optional<IslandOp> candidate = GetOperandCandidateToMergeWith(child);
if (!candidate.hasValue()) return false; if (!candidate.hasValue()) return false;
auto& parent = candidate.getValue(); MergeIslands(candidate.getValue(), child, kParentIsland);
MergeIslands(parent, child, IslandType::kParentIsland);
return true; return true;
} }
@ -283,17 +274,16 @@ bool ExecutorIslandCoarsening::MergeIslandWithOperand(IslandOp child) {
// must be another IslandOp for merging to take place. A new island is created // must be another IslandOp for merging to take place. A new island is created
// and the islands being merged are removed if a merge took place. Returns true // and the islands being merged are removed if a merge took place. Returns true
// if the island was merged with its result. // if the island was merged with its result.
bool ExecutorIslandCoarsening::MergeIslandWithResult(IslandOp parent) { bool MergeIslandWithResult(IslandOp parent) {
// Find candidate result to merge island with. // Find candidate result to merge island with.
llvm::Optional<IslandOp> candidate = GetResultCandidateToMergeWith(parent); llvm::Optional<IslandOp> candidate = GetResultCandidateToMergeWith(parent);
if (!candidate.hasValue()) return false; if (!candidate.hasValue()) return false;
auto& child = candidate.getValue(); MergeIslands(parent, candidate.getValue(), kChildIsland);
MergeIslands(parent, child, IslandType::kChildIsland); return true;
return false;
} }
void ExecutorIslandCoarsening::runOnFunction() { void ExecutorIslandCoarsening::runOnFunction() {
getFunction().walk([this](GraphOp graph) { getFunction().walk([](GraphOp graph) {
Block& graph_body = graph.GetBody(); Block& graph_body = graph.GetBody();
bool updated = false; bool updated = false;
@ -323,7 +313,8 @@ std::unique_ptr<FunctionPassBase> CreateTFExecutorIslandCoarseningPass() {
} }
static PassRegistration<ExecutorIslandCoarsening> pass( static PassRegistration<ExecutorIslandCoarsening> pass(
"tf-executor-island-coarsening", "Merges TFExecutor dialect IslandOps"); "tf-executor-island-coarsening",
"Merges TensorFlow executor dialect IslandOps");
} // namespace tf_executor } // namespace tf_executor
} // namespace mlir } // namespace mlir