Update TPUMergeVariablesWithExecutePass to handle tf_device.parallel_execute.

It is now possible for the TPU rewrite pass to generate tf_device.parallel_execute. This adds special handling for TPUExecutes in tf_device.parallel_execute, for both the replicated and non-replicated cases.

PiperOrigin-RevId: 301236260
Change-Id: Iae054654dc74f3447c809d0fc969a8b5fdfa545f
This commit is contained in:
Andy Ly 2020-03-16 14:17:01 -07:00 committed by TensorFlower Gardener
parent f26cea5ce6
commit 7f0f953f6b
2 changed files with 256 additions and 25 deletions

View File

@ -194,3 +194,99 @@ func @do_not_merge_multi_assign(
// CHECK-NEXT: return
return
}
// -----
// Tests that the pass merges only variable reads/writes on the same device,
// with TPUExecutes in a tf_device.parallel_execute.
// Each resource argument carries a tf.device attribute that matches the device
// of the tf_device.launch wrapping the TPUExecute consuming it, so each
// read/execute/assign triple on one device is expected to be fused into a
// single tf.TPUExecuteAndUpdateVariables taking the resource directly.
// CHECK-LABEL: func @parallel_execute
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*x!tf.resource<tensor<64xf32>>>
// CHECK-SAME: %[[ARG_2:.*]]: tensor<!tf.string>
func @parallel_execute(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
%arg1: tensor<*x!tf.resource<tensor<64xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:1"},
%arg2: tensor<!tf.string>) {
// These reads feed the TPUExecutes below; after the pass they must be gone.
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<64xf32>>>) -> tensor<64xf32>
// CHECK-NOT: "tf.ReadVariableOp"
// CHECK: "tf_device.parallel_execute"
// Two concurrent regions, one TPUExecute per device.
%pe:2 = "tf_device.parallel_execute"() ( {
// CHECK: "tf_device.launch"
%execute0 = "tf_device.launch"() ( {
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ARG_0]], %[[ARG_2]])
%0 = "tf.TPUExecute"(%read0, %arg2) : (tensor<32xf32>, tensor<!tf.string>) -> tensor<32xf32>
// CHECK-NEXT: tf_device.return
tf_device.return %0 : tensor<32xf32>
// CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:0"
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<32xf32>
tf_device.return %execute0 : tensor<32xf32>
}, {
// CHECK: "tf_device.launch"
%execute1 = "tf_device.launch"() ( {
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ARG_1]], %[[ARG_2]])
%1 = "tf.TPUExecute"(%read1, %arg2) : (tensor<64xf32>, tensor<!tf.string>) -> tensor<64xf32>
// CHECK-NEXT: tf_device.return
tf_device.return %1 : tensor<64xf32>
// CHECK-NEXT: device = "/job:localhost/replica:0/task:0/device:TPU:1"
}) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> tensor<64xf32>
tf_device.return %execute1 : tensor<64xf32>
}) : () -> (tensor<32xf32>, tensor<64xf32>)
// The assigns of the parallel_execute results must also be fused away.
// CHECK-NOT: "tf.AssignVariableOp"
"tf.AssignVariableOp"(%arg0, %pe#0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
"tf.AssignVariableOp"(%arg1, %pe#1) : (tensor<*x!tf.resource<tensor<64xf32>>>, tensor<64xf32>) -> ()
return
}
// -----
// Tests that the pass merges variable reads/writes for TPUExecutes in a
// tf_device.parallel_execute that is replicated (tf_device.replicate).
// Here the launches carry empty device attributes; inside a replicate the
// variables are expected to be merged using the replicated block arguments
// (%ri0/%ri1) rather than per-device attribute matching.
// CHECK-LABEL: func @replicated_parallel_execute
// CHECK-SAME: %[[ARG_0:[a-z0-9]+]]: tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-SAME: %[[ARG_1:[a-z0-9]+]]: tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-SAME: %[[ARG_2:[a-z0-9]+]]: tensor<*x!tf.resource<tensor<64xf32>>>
// CHECK-SAME: %[[ARG_3:[a-z0-9]+]]: tensor<*x!tf.resource<tensor<64xf32>>>
// CHECK-SAME: %[[ARG_4:[a-z0-9]+]]: tensor<!tf.string>
func @replicated_parallel_execute(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg2: tensor<*x!tf.resource<tensor<64xf32>>>,
%arg3: tensor<*x!tf.resource<tensor<64xf32>>>,
%arg4: tensor<!tf.string>) {
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]+]]: tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]+]]: tensor<*x!tf.resource<tensor<64xf32>>>
tf_device.replicate([%arg0, %arg1] as %ri0: tensor<*x!tf.resource<tensor<32xf32>>>,
[%arg2, %arg3] as %ri1: tensor<*x!tf.resource<tensor<64xf32>>>) {n = 2 : i32} {
// CHECK-NOT: "tf.ReadVariableOp"
%read0 = "tf.ReadVariableOp"(%ri0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%read1 = "tf.ReadVariableOp"(%ri1) : (tensor<*x!tf.resource<tensor<64xf32>>>) -> tensor<64xf32>
// CHECK: "tf_device.parallel_execute"
%pe:2 = "tf_device.parallel_execute"() ( {
// CHECK: "tf_device.launch"
%execute0 = "tf_device.launch"() ( {
// The merged execute should take the replicated resource arg directly.
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[RI_0]], %[[ARG_4]])
%0 = "tf.TPUExecute"(%read0, %arg4) : (tensor<32xf32>, tensor<!tf.string>) -> tensor<32xf32>
// CHECK-NEXT: tf_device.return
tf_device.return %0 : tensor<32xf32>
}) {device = ""} : () -> tensor<32xf32>
tf_device.return %execute0 : tensor<32xf32>
}, {
// CHECK: "tf_device.launch"
%execute1 = "tf_device.launch"() ( {
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[RI_1]], %[[ARG_4]])
%1 = "tf.TPUExecute"(%read1, %arg4) : (tensor<64xf32>, tensor<!tf.string>) -> tensor<64xf32>
// CHECK-NEXT: tf_device.return
tf_device.return %1 : tensor<64xf32>
}) {device = ""} : () -> tensor<64xf32>
tf_device.return %execute1 : tensor<64xf32>
}) : () -> (tensor<32xf32>, tensor<64xf32>)
// CHECK-NOT: "tf.AssignVariableOp"
"tf.AssignVariableOp"(%ri0, %pe#0) : (tensor<*x!tf.resource<tensor<32xf32>>>, tensor<32xf32>) -> ()
"tf.AssignVariableOp"(%ri1, %pe#1) : (tensor<*x!tf.resource<tensor<64xf32>>>, tensor<64xf32>) -> ()
}
return
}

View File

@ -41,7 +41,6 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@ -136,6 +135,10 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
// by inter-island dependencies.
Operation* first_read = nullptr;
Operation& execute = execute_launch.GetBody().front();
auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
execute_launch.getParentOp());
Operation* execute_parent =
parallel_execute ? parallel_execute.getOperation() : execute_launch;
// Find inputs that are variable reads.
for (auto operand : llvm::enumerate(execute.getOpOperands())) {
infos.new_operand_values.push_back(operand.value().get());
@ -144,9 +147,9 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
operand.value().get().getDefiningOp());
if (!read_op) continue;
if (check_same_region &&
read_op.getParentRegion() != execute_launch.getParentRegion()) {
read_op.getParentRegion() != execute_parent->getParentRegion())
continue;
}
auto resource = read_op.resource();
if (check_device) {
@ -193,9 +196,9 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
// work fine for the reads/assigns created by resource lifting, since they are
// placed close to the TPUExecute.
Operation* last_may_modify_resource_access_before_execute = nullptr;
for (Operation& op : llvm::reverse(
llvm::make_range(std::next(first_read->getIterator()),
execute_launch.getOperation()->getIterator()))) {
for (Operation& op :
llvm::reverse(llvm::make_range(std::next(first_read->getIterator()),
execute_parent->getIterator()))) {
if (llvm::dyn_cast<TF::ReadVariableOp>(&op)) continue;
if (!OpAccessesResource(&op)) continue;
last_may_modify_resource_access_before_execute = &op;
@ -232,10 +235,16 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
llvm::SmallPtrSet<Operation*, 8> all_assigns;
llvm::SmallVector<bool, 8> output_fused(execute_launch.getNumResults(),
false);
for (int i = 0; i < execute_launch.getNumResults(); ++i) {
auto execute_outputs =
parallel_execute
? parallel_execute.GetRegionOutputs(
execute_launch.getParentRegion()->getRegionNumber())
: execute_launch.getResults();
for (auto execute_output : llvm::enumerate(execute_outputs)) {
// TODO(lyandy): Handle updates to resource writes by remapping to parent
// launch result and checking if launch result is an AssignVariableOp.
auto result = execute_launch.getResult(i);
auto result = execute_output.value();
if (!result.hasOneUse()) continue;
auto assign_op = llvm::dyn_cast<TF::AssignVariableOp>(*result.user_begin());
if (!assign_op) continue;
@ -250,21 +259,20 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
infos.per_resource_info.shrink_and_clear();
return infos;
}
info.execute_output_index = i;
info.execute_output_index = execute_output.index();
info.assign = assign_op;
if (!last_assign || last_assign->isBeforeInBlock(assign_op)) {
last_assign = assign_op;
}
all_assigns.insert(assign_op);
output_fused[i] = true;
output_fused[execute_output.index()] = true;
}
// Check if there are other resource accesses after execute.
Operation* first_unknown_resource_access_after_execute = nullptr;
if (last_assign) {
for (auto& op : llvm::make_range(
std::next(execute_launch.getOperation()->getIterator()),
last_assign->getIterator())) {
for (auto& op : llvm::make_range(std::next(execute_parent->getIterator()),
last_assign->getIterator())) {
if (all_assigns.count(&op) > 0) continue;
if (!OpAccessesResource(&op)) continue;
first_unknown_resource_access_after_execute = &op;
@ -301,6 +309,115 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo(
return infos;
}
// Appends result types of tf_device.parallel_execute from `start` index region
// (inclusive) to `end` index region (exclusive) to `output_types` and returns
// the number of types added.
int AppendTypes(llvm::SmallVectorImpl<Type>* output_types,
                tf_device::ParallelExecuteOp parallel_execute, int start,
                int end) {
  int num_appended = 0;
  for (int region_index = start; region_index < end; ++region_index) {
    // Each region's result types are the operand types of its terminator.
    Operation* terminator =
        parallel_execute.GetRegionBlockWithIndex(region_index).getTerminator();
    for (Type operand_type : terminator->getOperandTypes()) {
      output_types->push_back(operand_type);
      ++num_appended;
    }
  }
  return num_appended;
}
// Replaces TPUExecute with TPUExecuteAndUpdateVariables in a
// tf_device.parallel_execute op.
//
// Result types of a tf_device.parallel_execute cannot be mutated in place, so
// a new parallel_execute with the updated result types is created, uses of the
// old results are remapped onto the new op, the regions are moved over, and
// finally the old ops are erased. The ordering of these steps matters:
// `execute_launch` must be erased before the regions are moved so the region
// containing the merged execute is left well formed.
void ReplaceParallelExecute(tf_device::ParallelExecuteOp parallel_execute,
                            tf_device::LaunchOp execute_launch,
                            tf_device::LaunchOp merged_execute_launch,
                            const VariableAccessesForTPUExecute& infos,
                            OpBuilder* builder) {
  Operation* parallel_execute_op = parallel_execute.getOperation();
  // Collect result types of tf_device.parallel_execute and update region
  // result types with the new merged execute result types.
  llvm::SmallVector<Type, 8> output_types;
  const int parallel_execute_num_results = parallel_execute_op->getNumResults();
  output_types.reserve(parallel_execute_num_results);
  Region* execute_region = merged_execute_launch.getParentRegion();
  const int region_index = execute_region->getRegionNumber();
  const int num_results_before_region =
      AppendTypes(&output_types, parallel_execute, 0, region_index);
  // Append updated results from merged execute.
  output_types.append(merged_execute_launch.getResultTypes().begin(),
                      merged_execute_launch.getResultTypes().end());
  const int num_regions = parallel_execute_op->getNumRegions();
  const int num_results_after_region = AppendTypes(
      &output_types, parallel_execute, region_index + 1, num_regions);
  builder->setInsertionPoint(parallel_execute);
  auto new_parallel_execute = builder->create<tf_device::ParallelExecuteOp>(
      parallel_execute.getLoc(), num_regions, output_types);
  // Replace the uses of the original parallel_execute before region containing
  // merged execute; these result indices are unchanged.
  Operation* new_parallel_execute_op = new_parallel_execute.getOperation();
  for (int i = 0; i < num_results_before_region; ++i)
    parallel_execute_op->getResult(i).replaceAllUsesWith(
        new_parallel_execute_op->getResult(i));
  // Replace the uses of the original parallel_execute after region containing
  // merged execute. The number of results changed in the region containing the
  // merged execute, but they should match, so results are replaced starting
  // from the ends of both parallel_execute.
  const int new_parallel_execute_num_results =
      new_parallel_execute_op->getNumResults();
  for (int i = 0; i < num_results_after_region; ++i)
    parallel_execute_op->getResult(parallel_execute_num_results - i - 1)
        .replaceAllUsesWith(new_parallel_execute_op->getResult(
            new_parallel_execute_num_results - i - 1));
  // Replace the uses of the original parallel_execute for the region containing
  // the merged execute. Entries < 0 in `old_to_new_output_mapping` have no
  // corresponding result in the merged execute and are skipped.
  auto old_region_results = parallel_execute.GetRegionOutputs(region_index);
  // Use size_t to match SmallVector::size() and avoid a signed/unsigned
  // comparison.
  for (size_t i = 0, e = infos.old_to_new_output_mapping.size(); i < e; ++i) {
    if (infos.old_to_new_output_mapping[i] < 0) continue;
    old_region_results[i].replaceAllUsesWith(new_parallel_execute_op->getResult(
        infos.old_to_new_output_mapping[i] + num_results_before_region));
  }
  // Replace original terminator with new terminator for returning merged
  // execute results.
  Operation* old_terminator = execute_region->front().getTerminator();
  builder->setInsertionPointToEnd(&execute_region->front());
  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
                                       merged_execute_launch.getResults());
  old_terminator->erase();
  // Remove the original TPUExecute op.
  execute_launch.erase();
  // Move all regions from old parallel_execute to new parallel_execute.
  for (auto region : llvm::zip(new_parallel_execute_op->getRegions(),
                               parallel_execute_op->getRegions()))
    std::get<0>(region).takeBody(std::get<1>(region));
  // Remove the original parallel_execute.
  parallel_execute_op->dropAllUses();
  parallel_execute.erase();
}
// Replaces TPUExecute with TPUExecuteAndUpdateVariables (non
// tf_device.parallel_execute case): remaps surviving results of the original
// launch onto the merged launch, then erases the original.
void ReplaceExecute(tf_device::LaunchOp execute_launch,
                    tf_device::LaunchOp merged_execute_launch,
                    const VariableAccessesForTPUExecute& infos) {
  // Replace the uses. Entries < 0 in `old_to_new_output_mapping` have no
  // corresponding result in the merged launch and are skipped. Use size_t to
  // match SmallVector::size() and avoid a signed/unsigned comparison.
  for (size_t i = 0, e = infos.old_to_new_output_mapping.size(); i < e; ++i) {
    if (infos.old_to_new_output_mapping[i] < 0) continue;
    execute_launch.getResult(i).replaceAllUsesWith(
        merged_execute_launch.getResult(infos.old_to_new_output_mapping[i]));
  }
  // Remove the original TPUExecute op.
  execute_launch.getOperation()->dropAllUses();
  execute_launch.erase();
}
// Merges the variable accesses into one TPUExecute op.
void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch,
bool check_device, bool check_same_region,
@ -352,19 +469,19 @@ void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch,
merged_execute.getOperation()->moveBefore(
merged_execute_launch.GetBody().getTerminator());
// Replace the uses.
for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) {
if (infos.old_to_new_output_mapping[i] < 0) continue;
execute_launch.getResult(i).replaceAllUsesWith(
merged_execute_launch.getResult(infos.old_to_new_output_mapping[i]));
}
if (auto parallel_execute = llvm::dyn_cast<tf_device::ParallelExecuteOp>(
execute_launch.getParentOp()))
ReplaceParallelExecute(parallel_execute, execute_launch,
merged_execute_launch, infos, builder);
else
ReplaceExecute(execute_launch, merged_execute_launch, infos);
// Remove the assign ops.
for (const auto& entry : infos.per_resource_info) {
const auto& info = entry.getSecond();
if (info.assign) info.assign->erase();
}
// Remove the original TPUExecute op.
execute_launch.erase();
// Remove the read ops if they have no more uses.
for (const auto& entry : infos.per_resource_info) {
const auto& info = entry.getSecond();
@ -372,25 +489,43 @@ void MergeForOneTPUExecute(tf_device::LaunchOp execute_launch,
}
}
// Checks if an op's parent is a tf_device.parallel_execute and, if so, whether
// the region the op is in perfectly wraps a single op. Ops not nested in a
// parallel_execute trivially satisfy the check.
bool ParentParallelExecuteWrapsSingleOp(Operation* op) {
  if (auto parallel_execute =
          llvm::dyn_cast<tf_device::ParallelExecuteOp>(op->getParentOp()))
    return parallel_execute.RegionWrapsSingleOp(
        op->getParentRegion()->getRegionNumber());
  return true;
}
void TPUMergeVariablesWithExecutePass::runOnFunction() {
// Find all the executes first, since we will mutate the nodes around each
// execute.
llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
getFunction().walk([&](tf_device::LaunchOp op) {
if (op.WrapsSingleOp() && llvm::isa<TF::TPUExecuteOp>(op.GetBody().front()))
if (op.WrapsSingleOp() &&
llvm::isa<TF::TPUExecuteOp>(op.GetBody().front()) &&
ParentParallelExecuteWrapsSingleOp(op))
execute_launches.push_back(op);
});
for (auto execute_launch : execute_launches) {
OpBuilder builder(&getContext());
const bool parent_is_replicate =
llvm::isa<tf_device::ReplicateOp>(execute_launch.getParentOp());
llvm::isa<tf_device::ReplicateOp>(execute_launch.getParentOp()) ||
(llvm::isa<tf_device::ParallelExecuteOp>(
execute_launch.getParentOp()) &&
llvm::isa<tf_device::ReplicateOp>(
execute_launch.getParentOp()->getParentOp()));
// If this is inside a tf_device::ReplicateOp, the variables are guaranteed
// to be on the same device as the TPUExecute op. Skip device checking in
// that case, but we need to check that we are only merging reads/assigns
// that are also in this replicated region.
MergeForOneTPUExecute(execute_launch, !parent_is_replicate,
parent_is_replicate, &builder);
MergeForOneTPUExecute(execute_launch, /*check_device=*/!parent_is_replicate,
/*check_same_region=*/parent_is_replicate, &builder);
}
}