Cleanup TensorFlow executor dialect island coarsening pass to match TPU cluster formation pass style more (NFC).

Updated functions to not return an output and have output parameters. Replaced `output` with `result`. Switched from storing island type and result index to results directly.

PiperOrigin-RevId: 266509565
This commit is contained in:
Andy Ly 2019-08-30 21:51:06 -07:00 committed by TensorFlower Gardener
parent 5025675366
commit 98bd5a6e22

View File

@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass takes TFExecutor dialect IslandOps and merges them.
// Note, this currently does not handle TensorFlow V1 style control flow/frames
// or side effecting ops yet.
// This transformation pass takes TensorFlow executor dialect IslandOps and
// merges them. Note, this currently does not handle TensorFlow V1 style control
// flow/frames or side effecting ops yet.
#include <iterator>
#include <tuple>
@ -46,25 +46,19 @@ namespace {
// merging another island or is the island (child) being being merged.
enum IslandType { kParentIsland, kChildIsland };
// Output is a helper struct holding a result index and island type (parent or
// child).
struct Output {
Output(IslandType island_type, int result_index)
: island_type(island_type), result_index(result_index) {}
// IslandResult is a helper struct holding an islands result and associated
// inner op result.
struct IslandResult {
IslandResult(Value* inner_op_result, Value* island_result)
: inner_op_result(inner_op_result), island_result(island_result) {}
IslandType island_type;
int result_index;
Value* inner_op_result;
Value* island_result;
};
struct ExecutorIslandCoarsening
: public FunctionPass<ExecutorIslandCoarsening> {
void runOnFunction() override;
private:
void MergeIslands(IslandOp parent, IslandOp child,
IslandType insert_position);
bool MergeIslandWithOperand(IslandOp child);
bool MergeIslandWithResult(IslandOp parent);
};
// Finds the operation leading to an island that the island can be merged with.
@ -97,7 +91,7 @@ llvm::Optional<IslandOp> GetOperandCandidateToMergeWith(IslandOp island) {
}
// Finds the operation leading from an island that the island can be merged
// with. This looks for the operation, either control output or data output to
// with. This looks for the operation, either control result or data result to
// an op, that is closest to the island in the graph. If no candidate can be
// found or the op found is not an island, an empty optional is returned.
llvm::Optional<IslandOp> GetResultCandidateToMergeWith(IslandOp island) {
@ -136,54 +130,60 @@ llvm::SmallSetVector<Value*, 8> GetNewIslandOperands(IslandOp parent,
return operands;
}
// Collects the results for the new island by going through each data output of
// Collects the results for the new island by going through each data result of
// the islands being merged. Unused results outside of the merged island to be
// formed are pruned. If the child island inner ops consume the parent island
// control output, the child island inner ops will have that respective control
// control result, the child island inner ops will have that respective control
// input pruned. Results of the parent island that are consumed by the child
// island are replaced by the respective inner ops output from the parent
// island are replaced by the respective inner ops result from the parent
// island.
llvm::SmallVector<Output, 8> GetNewIslandResultsAndForwardOutputs(
mlir::MLIRContext* context, IslandOp parent, IslandOp child,
llvm::SmallVector<Type, 8>* result_types) {
llvm::SmallVector<Output, 8> results;
llvm::SmallVector<IslandResult, 8> GetNewIslandResultsAndForwardResults(
IslandOp parent, IslandOp child) {
llvm::SmallVector<IslandResult, 8> results;
YieldOp yield_op = parent.GetYield();
Block& child_body = child.GetBody();
for (auto& ret_and_idx : llvm::enumerate(parent.outputs())) {
bool output_captured = false;
Value* yield_input = yield_op.getOperand(ret_and_idx.index());
for (auto& use :
llvm::make_early_inc_range(ret_and_idx.value()->getUses())) {
for (auto ret_vals :
llvm::zip(parent.GetYield().getOperands(), parent.outputs())) {
bool result_captured = false;
Value* inner_op_result = std::get<0>(ret_vals);
Value* island_result = std::get<1>(ret_vals);
for (auto& use : llvm::make_early_inc_range(island_result->getUses())) {
if (child_body.findAncestorInstInBlock(*use.getOwner())) {
// Forward output from inner op.
use.set(yield_input);
} else if (!output_captured) {
results.push_back(
Output(IslandType::kParentIsland, ret_and_idx.index()));
result_types->push_back(ret_and_idx.value()->getType());
output_captured = true;
// Forward result from inner op.
use.set(inner_op_result);
} else if (!result_captured) {
results.emplace_back(inner_op_result, island_result);
result_captured = true;
}
}
}
for (auto& ret_and_idx : llvm::enumerate(child.outputs())) {
if (!ret_and_idx.value()->use_empty()) {
results.push_back(Output(IslandType::kChildIsland, ret_and_idx.index()));
result_types->push_back(ret_and_idx.value()->getType());
for (auto ret_vals :
llvm::zip(child.GetYield().getOperands(), child.outputs())) {
Value* inner_op_result = std::get<0>(ret_vals);
Value* island_result = std::get<1>(ret_vals);
if (!island_result->use_empty()) {
results.emplace_back(inner_op_result, island_result);
}
}
// IslandOps always have a control output.
result_types->push_back(ControlType::get(context));
return results;
}
// Creates the new merged island.
IslandOp CreateNewIsland(Operation* old_island,
llvm::ArrayRef<Type> result_types,
llvm::ArrayRef<Value*> operands) {
IslandOp CreateNewIsland(IslandOp parent, IslandOp child,
IslandType insert_position,
llvm::ArrayRef<Value*> operands,
llvm::ArrayRef<IslandResult> results) {
// Collect types from results.
llvm::SmallVector<Type, 8> result_types;
for (const auto& result : results)
result_types.push_back(result.inner_op_result->getType());
// IslandOps always have a control result.
result_types.push_back(ControlType::get(parent.getContext()));
Operation* old_island = insert_position == kParentIsland ? parent : child;
OpBuilder builder(old_island);
auto new_island = builder.create<IslandOp>(
old_island->getLoc(), result_types, operands, ArrayRef<NamedAttribute>{});
@ -193,22 +193,18 @@ IslandOp CreateNewIsland(Operation* old_island,
// Creates respective YieldOp for the new merged island.
YieldOp CreateNewIslandYieldOp(IslandOp new_island,
llvm::ArrayRef<Output> results, IslandOp parent,
IslandOp child) {
llvm::ArrayRef<IslandResult> results) {
llvm::SmallVector<Value*, 8> yield_operands;
yield_operands.reserve(results.size());
for (auto ret_vals : llvm::zip(results, new_island.outputs())) {
// Get consumed output (island type and result index).
const auto& output = std::get<0>(ret_vals);
IslandOp& output_island =
output.island_type == IslandType::kParentIsland ? parent : child;
Value* result = output_island.getResult(output.result_index);
// Replace original result with new island result.
result->replaceAllUsesWith(std::get<1>(ret_vals));
// Find YieldOp in original island, grab the associated operand (inner op
// output) and add it as a operand to the YieldOp of the merged island.
yield_operands.push_back(
output_island.GetYield().getOperand(output.result_index));
const auto& old_result = std::get<0>(ret_vals);
// Replace original island result with new island result.
old_result.island_result->replaceAllUsesWith(std::get<1>(ret_vals));
// Add associated inner op result to operands of the YieldOp.
yield_operands.push_back(old_result.inner_op_result);
}
// Create YieldOp for the new island.
@ -234,25 +230,21 @@ void MoveInnerOpsToNewIsland(IslandOp parent, IslandOp child,
}
// Merges two islands and places new merged island before parent or child.
void ExecutorIslandCoarsening::MergeIslands(IslandOp parent, IslandOp child,
IslandType insert_position) {
void MergeIslands(IslandOp parent, IslandOp child, IslandType insert_position) {
// Collect operands for the new merged island.
llvm::SmallSetVector<Value*, 8> operands =
GetNewIslandOperands(parent, child);
// Collect results and result types for the new merged island.
llvm::SmallVector<Type, 8> result_types;
llvm::SmallVector<Output, 8> results = GetNewIslandResultsAndForwardOutputs(
&getContext(), parent, child, &result_types);
// Collect results for the new merged island.
llvm::SmallVector<IslandResult, 8> results =
GetNewIslandResultsAndForwardResults(parent, child);
// Create the new merged island.
IslandOp new_island = CreateNewIsland(
insert_position == IslandType::kParentIsland ? parent : child,
result_types, operands.getArrayRef());
IslandOp new_island = CreateNewIsland(parent, child, insert_position,
operands.getArrayRef(), results);
// Create associated YieldOp for the new merged island.
YieldOp new_yield_op =
CreateNewIslandYieldOp(new_island, results, parent, child);
YieldOp new_yield_op = CreateNewIslandYieldOp(new_island, results);
// Move inner ops from original islands into the new island.
MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation());
@ -270,12 +262,11 @@ void ExecutorIslandCoarsening::MergeIslands(IslandOp parent, IslandOp child,
// operand must be another IslandOp for merging to take place. A new island is
// created and the islands being merged are removed if a merge took place.
// Returns true if the island was merged with its operand.
bool ExecutorIslandCoarsening::MergeIslandWithOperand(IslandOp child) {
bool MergeIslandWithOperand(IslandOp child) {
// Find candidate operand to merge island with.
llvm::Optional<IslandOp> candidate = GetOperandCandidateToMergeWith(child);
if (!candidate.hasValue()) return false;
auto& parent = candidate.getValue();
MergeIslands(parent, child, IslandType::kParentIsland);
MergeIslands(candidate.getValue(), child, kParentIsland);
return true;
}
@ -283,17 +274,16 @@ bool ExecutorIslandCoarsening::MergeIslandWithOperand(IslandOp child) {
// must be another IslandOp for merging to take place. A new island is created
// and the islands being merged are removed if a merge took place. Returns true
// if the island was merged with its result.
bool ExecutorIslandCoarsening::MergeIslandWithResult(IslandOp parent) {
bool MergeIslandWithResult(IslandOp parent) {
// Find candidate result to merge island with.
llvm::Optional<IslandOp> candidate = GetResultCandidateToMergeWith(parent);
if (!candidate.hasValue()) return false;
auto& child = candidate.getValue();
MergeIslands(parent, child, IslandType::kChildIsland);
return false;
MergeIslands(parent, candidate.getValue(), kChildIsland);
return true;
}
void ExecutorIslandCoarsening::runOnFunction() {
getFunction().walk([this](GraphOp graph) {
getFunction().walk([](GraphOp graph) {
Block& graph_body = graph.GetBody();
bool updated = false;
@ -323,7 +313,8 @@ std::unique_ptr<FunctionPassBase> CreateTFExecutorIslandCoarseningPass() {
}
static PassRegistration<ExecutorIslandCoarsening> pass(
"tf-executor-island-coarsening", "Merges TFExecutor dialect IslandOps");
"tf-executor-island-coarsening",
"Merges TensorFlow executor dialect IslandOps");
} // namespace tf_executor
} // namespace mlir